From 8b731da5ba40041aaf0667045e1f679663aec69e Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Tue, 27 Jun 2023 15:46:32 +0800
Subject: [PATCH 001/326] Add new test case for ProbDist
---
.../_src/connect/tests/test_random_conn.py | 218 +++++++++---------
1 file changed, 114 insertions(+), 104 deletions(-)
diff --git a/brainpy/_src/connect/tests/test_random_conn.py b/brainpy/_src/connect/tests/test_random_conn.py
index d063b2c9d..de45a5ff0 100644
--- a/brainpy/_src/connect/tests/test_random_conn.py
+++ b/brainpy/_src/connect/tests/test_random_conn.py
@@ -8,149 +8,138 @@
class TestFixedProb(unittest.TestCase):
- def test_size_consistent(self):
- conn1 = bp.connect.FixedProb(prob=0.1, seed=123)
- conn1(pre_size=(10, 20), post_size=(10, 20))
- pre_ids, post_ids, pre2post = conn1.require('pre_ids', 'post_ids', 'pre2post')
- self.assertTrue(len(pre_ids) == len(post_ids))
- self.assertTrue(len(pre_ids) == len(pre2post[0]))
+ def test_size_consistent(self):
+ conn1 = bp.connect.FixedProb(prob=0.1, seed=123)
+ conn1(pre_size=(10, 20), post_size=(10, 20))
+ pre_ids, post_ids, pre2post = conn1.require('pre_ids', 'post_ids', 'pre2post')
+ self.assertTrue(len(pre_ids) == len(post_ids))
+ self.assertTrue(len(pre_ids) == len(pre2post[0]))
- def test_require_method(self):
- conn2 = bp.connect.FixedProb(prob=0.1, seed=123)
- conn2(pre_size=(10, 20), post_size=(10, 20))
- mat = conn2.require(bp.connect.CONN_MAT)
- self.assertTrue(mat.shape == (200, 200))
+ def test_require_method(self):
+ conn2 = bp.connect.FixedProb(prob=0.1, seed=123)
+ conn2(pre_size=(10, 20), post_size=(10, 20))
+ mat = conn2.require(bp.connect.CONN_MAT)
+ self.assertTrue(mat.shape == (200, 200))
- mat = conn2(100, 1000).require(bp.connect.CONN_MAT)
- self.assertTrue(mat.shape == (100, 1000))
+ mat = conn2(100, 1000).require(bp.connect.CONN_MAT)
+ self.assertTrue(mat.shape == (100, 1000))
- mat = conn2.require(10, 20, bp.connect.CONN_MAT)
- self.assertTrue(mat.shape == (10, 20))
+ mat = conn2.require(10, 20, bp.connect.CONN_MAT)
+ self.assertTrue(mat.shape == (10, 20))
def test_random_fix_pre1():
- for num in [0.4, 20]:
- conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
- mat1 = conn1.require(bp.connect.CONN_MAT)
+ for num in [0.4, 20]:
+ conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
+ mat1 = conn1.require(bp.connect.CONN_MAT)
- conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
- mat2 = conn2.require(bp.connect.CONN_MAT)
+ conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
+ mat2 = conn2.require(bp.connect.CONN_MAT)
- print()
- print(f'num = {num}')
- print('conn_mat 1\n', mat1)
- print(mat1.sum())
- print('conn_mat 2\n', mat2)
- print(mat2.sum())
+ print()
+ print(f'num = {num}')
+ print('conn_mat 1\n', mat1)
+ print(mat1.sum())
+ print('conn_mat 2\n', mat2)
+ print(mat2.sum())
- assert bp.math.array_equal(mat1, mat2)
+ assert bp.math.array_equal(mat1, mat2)
def test_random_fix_pre2():
- for num in [0.5, 3]:
- conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4)
- mat1 = conn1.require(bp.connect.CONN_MAT)
- print()
- print(mat1)
+ for num in [0.5, 3]:
+ conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4)
+ mat1 = conn1.require(bp.connect.CONN_MAT)
+ print()
+ print(mat1)
def test_random_fix_pre3():
- with pytest.raises(bp.errors.ConnectorError):
- conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4)
- conn1.require(bp.connect.CONN_MAT)
+ with pytest.raises(bp.errors.ConnectorError):
+ conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4)
+ conn1.require(bp.connect.CONN_MAT)
def test_random_fix_post1():
- for num in [0.4, 20]:
- conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
- mat1 = conn1.require(bp.connect.CONN_MAT)
+ for num in [0.4, 20]:
+ conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
+ mat1 = conn1.require(bp.connect.CONN_MAT)
- conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
- mat2 = conn2.require(bp.connect.CONN_MAT)
+ conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20))
+ mat2 = conn2.require(bp.connect.CONN_MAT)
- print()
- print('conn_mat 1\n', mat1)
- print('conn_mat 2\n', mat2)
+ print()
+ print('conn_mat 1\n', mat1)
+ print('conn_mat 2\n', mat2)
- assert bp.math.array_equal(mat1, mat2)
+ assert bp.math.array_equal(mat1, mat2)
def test_random_fix_post2():
- for num in [0.5, 3]:
- conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4)
- mat1 = conn1.require(bp.connect.CONN_MAT)
- print(mat1)
+ for num in [0.5, 3]:
+ conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4)
+ mat1 = conn1.require(bp.connect.CONN_MAT)
+ print(mat1)
def test_random_fix_post3():
- with pytest.raises(bp.errors.ConnectorError):
- conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
- conn1.require(bp.connect.CONN_MAT)
+ with pytest.raises(bp.errors.ConnectorError):
+ conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
+ conn1.require(bp.connect.CONN_MAT)
def test_gaussian_prob1():
- conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100)
- mat = conn.require(bp.connect.CONN_MAT)
+ conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100)
+ mat = conn.require(bp.connect.CONN_MAT)
- print()
- print('conn_mat', mat)
+ print()
+ print('conn_mat', mat)
def test_gaussian_prob2():
- conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50))
- mat = conn.require(bp.connect.CONN_MAT)
+ conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50))
+ mat = conn.require(bp.connect.CONN_MAT)
- print()
- print('conn_mat', mat)
+ print()
+ print('conn_mat', mat)
def test_gaussian_prob3():
- conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50))
- mat = conn.require(bp.connect.CONN_MAT)
+ conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50))
+ mat = conn.require(bp.connect.CONN_MAT)
- print()
- print('conn_mat', mat)
+ print()
+ print('conn_mat', mat)
def test_gaussian_prob4():
- conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10))
- conn.require(bp.connect.CONN_MAT,
- bp.connect.PRE_IDS, bp.connect.POST_IDS,
- bp.connect.PRE2POST, bp.connect.POST_IDS)
+ conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10))
+ conn.require(bp.connect.CONN_MAT,
+ bp.connect.PRE_IDS, bp.connect.POST_IDS,
+ bp.connect.PRE2POST, bp.connect.POST_IDS)
def test_SmallWorld1():
- conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
- conn(pre_size=10, post_size=10)
+ conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
+ conn(pre_size=10, post_size=10)
- mat = conn.require(bp.connect.CONN_MAT)
+ mat = conn.require(bp.connect.CONN_MAT)
- print('conn_mat', mat)
+ print('conn_mat', mat)
def test_SmallWorld3():
- conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True)
- conn(pre_size=20, post_size=20)
+ conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True)
+ conn(pre_size=20, post_size=20)
- mat = conn.require(bp.connect.CONN_MAT)
+ mat = conn.require(bp.connect.CONN_MAT)
- print('conn_mat', mat)
+ print('conn_mat', mat)
def test_SmallWorld2():
- conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5)
- conn(pre_size=(100,), post_size=(100,))
- mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
- bp.connect.PRE_IDS, bp.connect.POST_IDS,
- bp.connect.PRE2POST, bp.connect.POST_IDS)
- print()
- print('conn_mat', mat)
-
-
-def test_ScaleFreeBA():
- conn = bp.connect.ScaleFreeBA(m=2)
- for size in [100, (10, 20), (2, 10, 20)]:
- conn(pre_size=size, post_size=size)
+ conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5)
+ conn(pre_size=(100,), post_size=(100,))
mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
bp.connect.PRE_IDS, bp.connect.POST_IDS,
bp.connect.PRE2POST, bp.connect.POST_IDS)
@@ -158,23 +147,44 @@ def test_ScaleFreeBA():
print('conn_mat', mat)
-def test_ScaleFreeBADual():
- conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
- for size in [100, (10, 20), (2, 10, 20)]:
- conn(pre_size=size, post_size=size)
- mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
- bp.connect.PRE_IDS, bp.connect.POST_IDS,
- bp.connect.PRE2POST, bp.connect.POST_IDS)
- print()
- print('conn_mat', mat)
+def test_ScaleFreeBA():
+ conn = bp.connect.ScaleFreeBA(m=2)
+ for size in [100, (10, 20), (2, 10, 20)]:
+ conn(pre_size=size, post_size=size)
+ mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
+ bp.connect.PRE_IDS, bp.connect.POST_IDS,
+ bp.connect.PRE2POST, bp.connect.POST_IDS)
+ print()
+ print('conn_mat', mat)
-def test_PowerLaw():
- conn = bp.connect.PowerLaw(m=3, p=0.4)
- for size in [100, (10, 20), (2, 10, 20)]:
- conn(pre_size=size, post_size=size)
- mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
- bp.connect.PRE_IDS, bp.connect.POST_IDS,
- bp.connect.PRE2POST, bp.connect.POST_IDS)
+def test_ScaleFreeBADual():
+ conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
+ for size in [100, (10, 20), (2, 10, 20)]:
+ conn(pre_size=size, post_size=size)
+ mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
+ bp.connect.PRE_IDS, bp.connect.POST_IDS,
+ bp.connect.PRE2POST, bp.connect.POST_IDS)
print()
print('conn_mat', mat)
+
+
+def test_PowerLaw():
+ conn = bp.connect.PowerLaw(m=3, p=0.4)
+ for size in [100, (10, 20), (2, 10, 20)]:
+ conn(pre_size=size, post_size=size)
+ mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT,
+ bp.connect.PRE_IDS, bp.connect.POST_IDS,
+ bp.connect.PRE2POST, bp.connect.POST_IDS)
+ print()
+ print('conn_mat', mat)
+
+
+def test_prob_dist():
+ conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True)
+ for size in [100, (10, 20), (2, 10, 20), (2, 3, 4, 5)]:
+ conn(pre_size=size, post_size=size)
+ pre_ids, post_ids = conn.build_coo()
+ print()
+ print('Pre Ids:', pre_ids)
+ print('Post Ids:', post_ids)
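The new test_prob_dist case above exercises ProbDist.build_coo() directly. As a minimal standalone sketch of the same call pattern (parameter values copied from the test above, assuming only that BrainPy is importable):

    import brainpy as bp

    conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True)
    conn(pre_size=(10, 20), post_size=(10, 20))   # pre and post populations of the same shape
    pre_ids, post_ids = conn.build_coo()          # COO pair: source and target indices
    assert len(pre_ids) == len(post_ids)          # one entry per connection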
From 8038108ca9152586f28c47cc8b0fda8cea7c04ec Mon Sep 17 00:00:00 2001
From: Routhleck <1310722434@qq.com>
Date: Thu, 29 Jun 2023 08:56:56 +0800
Subject: [PATCH 002/326] Update base.py
---
brainpy/_src/connect/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/brainpy/_src/connect/base.py b/brainpy/_src/connect/base.py
index 9df2efd76..858fc54a7 100644
--- a/brainpy/_src/connect/base.py
+++ b/brainpy/_src/connect/base.py
@@ -726,7 +726,7 @@ def coo2csc(coo, post_num, data=None):
return pre_ids_new, indptr_new, data_new
-def visualizeMat(mat, description):
+def visualizeMat(mat, description='Untitled'):
try:
import seaborn as sns
import matplotlib.pyplot as plt
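Only the signature change and the first lines of the body appear in this hunk. A minimal sketch of what a heatmap-based visualizeMat could look like, assuming seaborn and matplotlib are optional dependencies as the try block suggests (the actual body in base.py may differ):

    def visualizeMat(mat, description='Untitled'):
        try:
            import seaborn as sns
            import matplotlib.pyplot as plt
        except (ImportError, ModuleNotFoundError):
            print('seaborn and matplotlib are required to visualize the connection matrix.')
            return
        sns.heatmap(mat, cbar=False)   # hypothetical plotting call on the dense connection matrix
        plt.title(description)         # the new default keeps the call valid without a caption
        plt.show()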
From ade0faf2a9b1780c4b1ef024d65fa01b9319b8bc Mon Sep 17 00:00:00 2001
From: Routhleck <1310722434@qq.com>
Date: Thu, 29 Jun 2023 09:23:26 +0800
Subject: [PATCH 003/326] Add timing tests for all connectors
---
brainpy/_src/connect/tests/test_all_time.py | 329 ++++++++++++++++++++
1 file changed, 329 insertions(+)
create mode 100644 brainpy/_src/connect/tests/test_all_time.py
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
new file mode 100644
index 000000000..f11927dae
--- /dev/null
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -0,0 +1,329 @@
+import time
+import brainpy as bp
+import unittest
+import pytest
+import pandas as pd
+
+df = pd.DataFrame(
+ columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter',
+ 'time(ms)'])
+
+size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
+size_diff = [(10, 100), (100, 1000), (1000, 10000), (10000, 100000)]
+
+
+def get_ms(value):
+ return round(value * 1000, 4)
+
+
+class OneEndConnector(unittest.TestCase):
+ def test_gaussian_prob(self):
+ for size in size_same:
+ conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size)
+ mat = conn.build_mat()
+ start = time.time()
+ mat = conn.build_mat()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GaussianProb',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=1/include_self=False',
+ time_used]
+
+ def test_grid(self):
+ for size in size_same:
+ conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size)
+ start = time.time()
+ mat = conn.build_mat()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'include_self=False/periodic_boundary=False',
+ time_used]
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'include_self=False/periodic_boundary=False',
+ time_used]
+
+
+class TwoEndConnector(unittest.TestCase):
+ def test_fixed_prob(self):
+ for size in size_same:
+ conn = bp.connect.FixedProb(prob=0.1, seed=123)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.build_mat()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedProb',
+ 'TwoEndConnector',
+ f'{size}×{size}',
+ 'build_mat',
+ 'prob=0.1',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedProb',
+ 'TwoEndConnector',
+ f'{size}×{size}',
+ 'build_coo',
+ 'prob=0.1',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_csr()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedProb',
+ 'TwoEndConnector',
+ f'{size}×{size}',
+ 'build_csr',
+ 'prob=0.1',
+ time_used]
+
+ for size in size_diff:
+ conn = bp.connect.FixedProb(prob=0.1, seed=123)
+ conn(pre_size=size[0], post_size=size[1])
+ start = time.time()
+ mat = conn.build_mat()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedProb',
+ 'TwoEndConnector',
+ f'{size[0]}×{size[1]}',
+ 'build_mat',
+ 'prob=0.1',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedProb',
+ 'TwoEndConnector',
+ f'{size[0]}×{size[1]}',
+ 'build_coo',
+ 'prob=0.1',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_csr()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedProb',
+ 'TwoEndConnector',
+ f'{size[0]}×{size[1]}',
+ 'build_csr',
+ 'prob=0.1',
+ time_used]
+
+ def test_fixed_pre_num(self):
+ for size in size_same:
+ conn = bp.connect.FixedPreNum(num=0.4, seed=123)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'pre_num=10',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'pre_num=10',
+ time_used]
+
+ for size in size_diff:
+ conn = bp.connect.FixedPreNum(num=0.4, seed=123)
+ conn(pre_size=size[0], post_size=size[1])
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_mat',
+ 'pre_num=10',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_coo',
+ 'pre_num=10',
+ time_used]
+
+ def test_fixed_post_num(self):
+ for size in size_same:
+ conn = bp.connect.FixedPostNum(num=10, seed=123)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'num=10',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'num=10',
+ time_used]
+
+ for size in size_diff:
+ conn = bp.connect.FixedPreNum(num=10, seed=123)
+ conn(pre_size=size[0], post_size=size[1])
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_mat',
+ 'pre_num=10',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_coo',
+ 'pre_num=10',
+ time_used]
+
+ def test_prob_dist(self):
+ for size in size_same:
+ conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ProbDist',
+ 'TwoEndConnector',
+ f'{size}×{size}',
+ 'build_mat',
+ 'prob=0.5',
+ time_used]
+
+ start = time.time()
+ pre_ids, post_ids = conn.build_coo()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ProbDist',
+ 'TwoEndConnector',
+ f'{size}×{size}',
+ 'build_coo',
+ 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
+ time_used]
+
+ def test_small_world(self):
+ for size in size_same:
+ conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['SmallWorld',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'num_neighbor=2/prob=0.5/include_self=False',
+ time_used]
+
+ def test_scale_free_ba(self):
+ for size in size_same:
+ conn = bp.connect.ScaleFreeBA(m=2)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ScaleFreeBA',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=2',
+ time_used]
+
+ def test_scale_free_ba_dual(self):
+ for size in size_same:
+ conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ScaleFreeBADual',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm1=2/m2=3/p=0.4',
+ time_used]
+
+ def test_power_law(self):
+ for size in size_same:
+ conn = bp.connect.PowerLaw(m=3, p=0.4)
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['PowerLaw',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=3/p=0.4',
+ time_used]
+
+ def test_one2one(self):
+ for size in size_same:
+ conn = bp.connect.One2One()
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.build_mat()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['One2One',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ '',
+ time_used]
+
+ def test_all2all(self):
+ for size in size_same:
+ conn = bp.connect.All2All()
+ conn(pre_size=size, post_size=size)
+ start = time.time()
+ mat = conn.build_mat()
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['All2All',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ '',
+ time_used]
+
+class TestSave(unittest.TestCase):
+ def test_save(self):
+ df.to_csv('time.csv', index=False)
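Every case in this new file follows the same pattern: build the connector, wrap the build call with time.time(), convert to milliseconds with get_ms, and append a row to df. A sketch of that pattern in isolation (time_build is a hypothetical helper, not part of the file):

    import time
    import brainpy as bp

    def time_build(conn, build):
        # Time one build call and return milliseconds, mirroring get_ms() above.
        start = time.time()
        build()
        return round((time.time() - start) * 1000, 4)

    conn = bp.connect.FixedProb(prob=0.1, seed=123)
    conn(pre_size=1000, post_size=1000)
    print('build_mat:', time_build(conn, conn.build_mat), 'ms')
    print('build_coo:', time_build(conn, conn.build_coo), 'ms')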
From 13a7b706a05f77500c5f6678a6aef8fd7cd46ff4 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 12:56:58 +0800
Subject: [PATCH 004/326] Update test_all_time.py
---
brainpy/_src/connect/tests/test_all_time.py | 336 ++++++++++++++++++--
1 file changed, 310 insertions(+), 26 deletions(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index f11927dae..93464e61b 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -20,9 +20,9 @@ class OneEndConnector(unittest.TestCase):
def test_gaussian_prob(self):
for size in size_same:
conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size)
- mat = conn.build_mat()
+
start = time.time()
- mat = conn.build_mat()
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['GaussianProb',
'OneEndConnector',
@@ -31,11 +31,32 @@ def test_gaussian_prob(self):
'sigma=1/include_self=False',
time_used]
- def test_grid(self):
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GaussianProb',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'sigma=1/include_self=False',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GaussianProb',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'sigma=1/include_self=False',
+ time_used]
+
+ def test_grid_four(self):
for size in size_same:
conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size)
+
start = time.time()
- mat = conn.build_mat()
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['GridFour',
'OneEndConnector',
@@ -43,24 +64,104 @@ def test_grid(self):
'build_mat',
'include_self=False/periodic_boundary=False',
time_used]
+
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'include_self=False/periodic_boundary=False',
+ time_used]
+
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'include_self=False/periodic_boundary=False',
+ time_used]
+
+ def test_grid_eight(self):
+ for size in size_same:
+ conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size)
+
+ start = time.time()
+ conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridEight',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'include_self=False/periodic_boundary=False',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridEight',
'OneEndConnector',
f'{size}x{size}',
'build_coo',
'include_self=False/periodic_boundary=False',
time_used]
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridEight',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'include_self=False/periodic_boundary=False',
+ time_used]
+
+ def test_grid_n(self):
+ for size in size_same:
+ conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size)
+
+ start = time.time()
+ conn.require(bp.connect.CONN_MAT)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridN',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'include_self=False/periodic_boundary=False/N=2',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridN',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'include_self=False/periodic_boundary=False/N=2',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['GridN',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'include_self=False/periodic_boundary=False/N=2',
+ time_used]
+
class TwoEndConnector(unittest.TestCase):
def test_fixed_prob(self):
for size in size_same:
conn = bp.connect.FixedProb(prob=0.1, seed=123)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.build_mat()
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
@@ -70,7 +171,7 @@ def test_fixed_prob(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
@@ -80,7 +181,7 @@ def test_fixed_prob(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_csr()
+ conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
@@ -92,8 +193,9 @@ def test_fixed_prob(self):
for size in size_diff:
conn = bp.connect.FixedProb(prob=0.1, seed=123)
conn(pre_size=size[0], post_size=size[1])
+
start = time.time()
- mat = conn.build_mat()
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
@@ -103,7 +205,7 @@ def test_fixed_prob(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
@@ -113,7 +215,7 @@ def test_fixed_prob(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_csr()
+ conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
@@ -126,8 +228,9 @@ def test_fixed_pre_num(self):
for size in size_same:
conn = bp.connect.FixedPreNum(num=0.4, seed=123)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -137,7 +240,7 @@ def test_fixed_pre_num(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -146,11 +249,22 @@ def test_fixed_pre_num(self):
'pre_num=10',
time_used]
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'pre_num=10',
+ time_used]
+
for size in size_diff:
conn = bp.connect.FixedPreNum(num=0.4, seed=123)
conn(pre_size=size[0], post_size=size[1])
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -160,7 +274,7 @@ def test_fixed_pre_num(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -169,10 +283,21 @@ def test_fixed_pre_num(self):
'pre_num=10',
time_used]
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_csr',
+ 'pre_num=10',
+ time_used]
+
def test_fixed_post_num(self):
for size in size_same:
conn = bp.connect.FixedPostNum(num=10, seed=123)
conn(pre_size=size, post_size=size)
+
start = time.time()
mat = conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
@@ -184,7 +309,7 @@ def test_fixed_post_num(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -193,11 +318,22 @@ def test_fixed_post_num(self):
'num=10',
time_used]
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'num=10',
+ time_used]
+
for size in size_diff:
conn = bp.connect.FixedPreNum(num=10, seed=123)
conn(pre_size=size[0], post_size=size[1])
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -207,7 +343,7 @@ def test_fixed_post_num(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedPreNum',
'TwoEndConnector',
@@ -216,12 +352,23 @@ def test_fixed_post_num(self):
'pre_num=10',
time_used]
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_csr',
+ 'pre_num=10',
+ time_used]
+
def test_prob_dist(self):
for size in size_same:
conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ProbDist',
'TwoEndConnector',
@@ -231,7 +378,7 @@ def test_prob_dist(self):
time_used]
start = time.time()
- pre_ids, post_ids = conn.build_coo()
+ conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ProbDist',
'TwoEndConnector',
@@ -240,12 +387,23 @@ def test_prob_dist(self):
'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
time_used]
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ProbDist',
+ 'TwoEndConnector',
+ f'{size}×{size}',
+ 'build_csr',
+ 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
+ time_used]
+
def test_small_world(self):
for size in size_same:
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['SmallWorld',
'TwoEndConnector',
@@ -254,12 +412,33 @@ def test_small_world(self):
'num_neighbor=2/prob=0.5/include_self=False',
time_used]
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['SmallWorld',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'num_neighbor=2/prob=0.5/include_self=False',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['SmallWorld',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'num_neighbor=2/prob=0.5/include_self=False',
+ time_used]
+
def test_scale_free_ba(self):
for size in size_same:
conn = bp.connect.ScaleFreeBA(m=2)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ScaleFreeBA',
'TwoEndConnector',
@@ -268,12 +447,33 @@ def test_scale_free_ba(self):
'm=2',
time_used]
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ScaleFreeBA',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'm=2',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ScaleFreeBA',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'm=2',
+ time_used]
+
def test_scale_free_ba_dual(self):
for size in size_same:
conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ScaleFreeBADual',
'TwoEndConnector',
@@ -282,12 +482,33 @@ def test_scale_free_ba_dual(self):
'm1=2/m2=3/p=0.4',
time_used]
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ScaleFreeBADual',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'm1=2/m2=3/p=0.4',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['ScaleFreeBADual',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'm1=2/m2=3/p=0.4',
+ time_used]
+
def test_power_law(self):
for size in size_same:
conn = bp.connect.PowerLaw(m=3, p=0.4)
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.require(bp.connect.CONN_MAT)
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['PowerLaw',
'TwoEndConnector',
@@ -296,12 +517,33 @@ def test_power_law(self):
'm=3/p=0.4',
time_used]
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['PowerLaw',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'm=3/p=0.4',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['PowerLaw',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'm=3/p=0.4',
+ time_used]
+
def test_one2one(self):
for size in size_same:
conn = bp.connect.One2One()
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.build_mat()
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['One2One',
'TwoEndConnector',
@@ -310,12 +552,33 @@ def test_one2one(self):
'',
time_used]
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['One2One',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ '',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['One2One',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ '',
+ time_used]
+
def test_all2all(self):
for size in size_same:
conn = bp.connect.All2All()
conn(pre_size=size, post_size=size)
+
start = time.time()
- mat = conn.build_mat()
+ conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['All2All',
'TwoEndConnector',
@@ -324,6 +587,27 @@ def test_all2all(self):
'',
time_used]
+ start = time.time()
+ conn.require(bp.connect.COO)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['All2All',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ '',
+ time_used]
+
+ start = time.time()
+ conn.require(bp.connect.CSR)
+ time_used = get_ms(time.time() - start)
+ df.loc[len(df)] = ['All2All',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ '',
+ time_used]
+
+
class TestSave(unittest.TestCase):
def test_save(self):
df.to_csv('time.csv', index=False)
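Patch 004 switches the benchmark from the concrete build_mat/build_coo/build_csr methods to the generic require interface with the CONN_MAT, COO and CSR descriptors. The two spellings for a single connector, side by side (a small sketch following the usage seen in the tests above):

    import brainpy as bp

    conn = bp.connect.FixedProb(prob=0.1, seed=123)
    conn(pre_size=100, post_size=100)

    mat = conn.build_mat()                    # concrete builder: dense connection matrix
    mat2 = conn.require(bp.connect.CONN_MAT)  # generic interface, same structure
    pre_ids, post_ids = conn.build_coo()      # concrete COO builder
    conn.require(bp.connect.COO)              # generic equivalent
    conn.require(bp.connect.CSR)              # CSR structure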
From 3caf39665f392f81e8d8b1d5944d61c1d9b64e9b Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 13:01:07 +0800
Subject: [PATCH 005/326] Update test_all_time.py
---
brainpy/_src/connect/tests/test_all_time.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index 93464e61b..7c735ec01 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -1,4 +1,6 @@
import time
+from datetime import datetime
+
import brainpy as bp
import unittest
import pytest
@@ -610,4 +612,5 @@ def test_all2all(self):
class TestSave(unittest.TestCase):
def test_save(self):
- df.to_csv('time.csv', index=False)
+ df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv',
+ index=False)
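The timestamped filename keeps each benchmark run in a separate CSV instead of overwriting time.csv. The naming scheme used above, in isolation (the toy DataFrame is only for illustration):

    from datetime import datetime
    import pandas as pd

    df = pd.DataFrame(columns=['connector name', 'time(ms)'])  # toy frame for illustration
    filename = 'connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv'
    df.to_csv(filename, index=False)  # writes e.g. connector_time_2023-06-29_13-01-07.csv, a new file per run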
From 625dd0aa71b4023d0658033861937d0516064530 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 13:22:05 +0800
Subject: [PATCH 006/326] Update test_all_time.py
---
brainpy/_src/connect/tests/test_all_time.py | 36 ++++++++++-----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index 7c735ec01..4888cde92 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -33,15 +33,15 @@ def test_gaussian_prob(self):
'sigma=1/include_self=False',
time_used]
- start = time.time()
- conn.require(bp.connect.COO)
- time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GaussianProb',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'sigma=1/include_self=False',
- time_used]
+ # start = time.time()
+ # conn.require(bp.connect.COO)
+ # time_used = get_ms(time.time() - start)
+ # df.loc[len(df)] = ['GaussianProb',
+ # 'OneEndConnector',
+ # f'{size}x{size}',
+ # 'build_coo',
+ # 'sigma=1/include_self=False',
+ # time_used]
start = time.time()
conn.require(bp.connect.CSR)
@@ -589,15 +589,15 @@ def test_all2all(self):
'',
time_used]
- start = time.time()
- conn.require(bp.connect.COO)
- time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['All2All',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- '',
- time_used]
+ # start = time.time()
+ # conn.require(bp.connect.COO)
+ # time_used = get_ms(time.time() - start)
+ # df.loc[len(df)] = ['All2All',
+ # 'TwoEndConnector',
+ # f'{size}x{size}',
+ # 'build_coo',
+ # '',
+ # time_used]
start = time.time()
conn.require(bp.connect.CSR)
From ba226e176784c8b868b191f39dd4a2228c7f34dd Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 14:40:14 +0800
Subject: [PATCH 007/326] Update test_all_time.py
---
brainpy/_src/connect/tests/test_all_time.py | 54 ++++++++++++++++-----
1 file changed, 43 insertions(+), 11 deletions(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index 4888cde92..aa046a1e4 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -10,8 +10,9 @@
columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter',
'time(ms)'])
-size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
-size_diff = [(10, 100), (100, 1000), (1000, 10000), (10000, 100000)]
+# size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
+size_same = [100, 500, 2500, 12500]
+size_diff = [(10, 100), (100, 1000), (1000, 10000)]
def get_ms(value):
@@ -20,7 +21,9 @@ def get_ms(value):
class OneEndConnector(unittest.TestCase):
def test_gaussian_prob(self):
+ print()
for size in size_same:
+ print('GaussianProb:', size)
conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size)
start = time.time()
@@ -54,7 +57,9 @@ def test_gaussian_prob(self):
time_used]
def test_grid_four(self):
+ print()
for size in size_same:
+ print('GridFour:', size)
conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size)
start = time.time()
@@ -88,7 +93,9 @@ def test_grid_four(self):
time_used]
def test_grid_eight(self):
+ print()
for size in size_same:
+ print('GridEight:', size)
conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size)
start = time.time()
@@ -122,7 +129,9 @@ def test_grid_eight(self):
time_used]
def test_grid_n(self):
+ print()
for size in size_same:
+ print('GridN:', size)
conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size)
start = time.time()
@@ -158,7 +167,9 @@ def test_grid_n(self):
class TwoEndConnector(unittest.TestCase):
def test_fixed_prob(self):
+ print()
for size in size_same:
+ print('FixedProb:', size)
conn = bp.connect.FixedProb(prob=0.1, seed=123)
conn(pre_size=size, post_size=size)
@@ -167,7 +178,7 @@ def test_fixed_prob(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
- f'{size}×{size}',
+ f'{size}x{size}',
'build_mat',
'prob=0.1',
time_used]
@@ -177,7 +188,7 @@ def test_fixed_prob(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
- f'{size}×{size}',
+ f'{size}x{size}',
'build_coo',
'prob=0.1',
time_used]
@@ -187,12 +198,13 @@ def test_fixed_prob(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
- f'{size}×{size}',
+ f'{size}x{size}',
'build_csr',
'prob=0.1',
time_used]
for size in size_diff:
+ print('FixedProb:', size)
conn = bp.connect.FixedProb(prob=0.1, seed=123)
conn(pre_size=size[0], post_size=size[1])
@@ -201,7 +213,7 @@ def test_fixed_prob(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
- f'{size[0]}×{size[1]}',
+ f'{size[0]}x{size[1]}',
'build_mat',
'prob=0.1',
time_used]
@@ -211,7 +223,7 @@ def test_fixed_prob(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
- f'{size[0]}×{size[1]}',
+ f'{size[0]}x{size[1]}',
'build_coo',
'prob=0.1',
time_used]
@@ -221,13 +233,15 @@ def test_fixed_prob(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['FixedProb',
'TwoEndConnector',
- f'{size[0]}×{size[1]}',
+ f'{size[0]}x{size[1]}',
'build_csr',
'prob=0.1',
time_used]
def test_fixed_pre_num(self):
+ print()
for size in size_same:
+ print('FixedPreNum:', size)
conn = bp.connect.FixedPreNum(num=0.4, seed=123)
conn(pre_size=size, post_size=size)
@@ -262,6 +276,7 @@ def test_fixed_pre_num(self):
time_used]
for size in size_diff:
+ print('FixedPreNum:', size)
conn = bp.connect.FixedPreNum(num=0.4, seed=123)
conn(pre_size=size[0], post_size=size[1])
@@ -296,7 +311,9 @@ def test_fixed_pre_num(self):
time_used]
def test_fixed_post_num(self):
+ print()
for size in size_same:
+ print('FixedPostNum:', size)
conn = bp.connect.FixedPostNum(num=10, seed=123)
conn(pre_size=size, post_size=size)
@@ -331,6 +348,7 @@ def test_fixed_post_num(self):
time_used]
for size in size_diff:
+ print('FixedPostNum:', size)
conn = bp.connect.FixedPreNum(num=10, seed=123)
conn(pre_size=size[0], post_size=size[1])
@@ -365,7 +383,9 @@ def test_fixed_post_num(self):
time_used]
def test_prob_dist(self):
+ print()
for size in size_same:
+ print('ProbDist:', size)
conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True)
conn(pre_size=size, post_size=size)
@@ -374,7 +394,7 @@ def test_prob_dist(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ProbDist',
'TwoEndConnector',
- f'{size}×{size}',
+ f'{size}x{size}',
'build_mat',
'prob=0.5',
time_used]
@@ -384,7 +404,7 @@ def test_prob_dist(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ProbDist',
'TwoEndConnector',
- f'{size}×{size}',
+ f'{size}x{size}',
'build_coo',
'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
time_used]
@@ -394,13 +414,15 @@ def test_prob_dist(self):
time_used = get_ms(time.time() - start)
df.loc[len(df)] = ['ProbDist',
'TwoEndConnector',
- f'{size}×{size}',
+ f'{size}x{size}',
'build_csr',
'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
time_used]
def test_small_world(self):
+ print()
for size in size_same:
+ print('SmallWorld:', size)
conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False)
conn(pre_size=size, post_size=size)
@@ -435,7 +457,9 @@ def test_small_world(self):
time_used]
def test_scale_free_ba(self):
+ print()
for size in size_same:
+ print('ScaleFreeBA:', size)
conn = bp.connect.ScaleFreeBA(m=2)
conn(pre_size=size, post_size=size)
@@ -470,7 +494,9 @@ def test_scale_free_ba(self):
time_used]
def test_scale_free_ba_dual(self):
+ print()
for size in size_same:
+ print('ScaleFreeBADual:', size)
conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4)
conn(pre_size=size, post_size=size)
@@ -505,7 +531,9 @@ def test_scale_free_ba_dual(self):
time_used]
def test_power_law(self):
+ print()
for size in size_same:
+ print('PowerLaw:', size)
conn = bp.connect.PowerLaw(m=3, p=0.4)
conn(pre_size=size, post_size=size)
@@ -540,7 +568,9 @@ def test_power_law(self):
time_used]
def test_one2one(self):
+ print()
for size in size_same:
+ print('One2One:', size)
conn = bp.connect.One2One()
conn(pre_size=size, post_size=size)
@@ -575,7 +605,9 @@ def test_one2one(self):
time_used]
def test_all2all(self):
+ print()
for size in size_same:
+ print('All2All:', size)
conn = bp.connect.All2All()
conn(pre_size=size, post_size=size)
From 91cd1d3ec6195ec40cf391c9d402467655ee2cb6 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 14:45:38 +0800
Subject: [PATCH 008/326] Add try/except around pandas import
---
brainpy/_src/connect/tests/test_all_time.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index aa046a1e4..4e6d3bc76 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -4,7 +4,10 @@
import brainpy as bp
import unittest
import pytest
-import pandas as pd
+try:
+ import pandas as pd
+except (ImportError, ModuleNotFoundError):
+ print('No pandas installed, skip test.')
df = pd.DataFrame(
columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter',
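This guard only covers the import itself: when pandas is missing, the df = pd.DataFrame(...) line below still raises NameError. Patch 009 later moves the DataFrame construction inside the try block and routes all writes through an insert_row helper that catches NameError/UnboundLocalError. A sketch of an equivalent guard using a None sentinel instead (an alternative, not the approach the series actually takes):

    try:
        import pandas as pd
        df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size',
                                   'build function', 'other parameter', 'time(ms)'])
    except (ImportError, ModuleNotFoundError):
        df = None
        print('No pandas installed, skip test.')

    def insert_row(*row):
        # Append a result row only when pandas is available.
        if df is not None:
            df.loc[len(df)] = list(row)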
From caeba8593ce04d6fe477e0911db39dce143d301c Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 15:18:25 +0800
Subject: [PATCH 009/326] Update test_all_time.py
---
brainpy/_src/connect/tests/test_all_time.py | 604 ++++++++++----------
1 file changed, 306 insertions(+), 298 deletions(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index 4e6d3bc76..2252e3a82 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -4,15 +4,16 @@
import brainpy as bp
import unittest
import pytest
+
try:
import pandas as pd
+
+ df = pd.DataFrame(
+ columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter',
+ 'time(ms)'])
except (ImportError, ModuleNotFoundError):
print('No pandas installed, skip test.')
-df = pd.DataFrame(
- columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter',
- 'time(ms)'])
-
# size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
size_same = [100, 500, 2500, 12500]
size_diff = [(10, 100), (100, 1000), (1000, 10000)]
@@ -22,6 +23,13 @@ def get_ms(value):
return round(value * 1000, 4)
+def insert_row(connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used):
+ try:
+ df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used]
+ except (NameError, UnboundLocalError):
+ print('No pandas installed, skip test.')
+
+
class OneEndConnector(unittest.TestCase):
def test_gaussian_prob(self):
print()
@@ -32,12 +40,12 @@ def test_gaussian_prob(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GaussianProb',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'sigma=1/include_self=False',
- time_used]
+ insert_row('GaussianProb',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=1/include_self=False',
+ time_used)
# start = time.time()
# conn.require(bp.connect.COO)
@@ -52,12 +60,12 @@ def test_gaussian_prob(self):
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GaussianProb',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'sigma=1/include_self=False',
- time_used]
+ insert_row('GaussianProb',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'sigma=1/include_self=False',
+ time_used)
def test_grid_four(self):
print()
@@ -68,32 +76,32 @@ def test_grid_four(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridFour',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'include_self=False/periodic_boundary=False',
- time_used]
+ insert_row('GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'include_self=False/periodic_boundary=False',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridFour',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'include_self=False/periodic_boundary=False',
- time_used]
+ insert_row('GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'include_self=False/periodic_boundary=False',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridFour',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'include_self=False/periodic_boundary=False',
- time_used]
+ insert_row('GridFour',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'include_self=False/periodic_boundary=False',
+ time_used)
def test_grid_eight(self):
print()
@@ -104,32 +112,32 @@ def test_grid_eight(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridEight',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'include_self=False/periodic_boundary=False',
- time_used]
+ insert_row('GridEight',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'include_self=False/periodic_boundary=False',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridEight',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'include_self=False/periodic_boundary=False',
- time_used]
+ insert_row('GridEight',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'include_self=False/periodic_boundary=False',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridEight',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'include_self=False/periodic_boundary=False',
- time_used]
+ insert_row('GridEight',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'include_self=False/periodic_boundary=False',
+ time_used)
def test_grid_n(self):
print()
@@ -140,32 +148,32 @@ def test_grid_n(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridN',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'include_self=False/periodic_boundary=False/N=2',
- time_used]
+ insert_row('GridN',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'include_self=False/periodic_boundary=False/N=2',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridN',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'include_self=False/periodic_boundary=False/N=2',
- time_used]
+ insert_row('GridN',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'include_self=False/periodic_boundary=False/N=2',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['GridN',
- 'OneEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'include_self=False/periodic_boundary=False/N=2',
- time_used]
+ insert_row('GridN',
+ 'OneEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'include_self=False/periodic_boundary=False/N=2',
+ time_used)
class TwoEndConnector(unittest.TestCase):
@@ -179,32 +187,32 @@ def test_fixed_prob(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedProb',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'prob=0.1',
- time_used]
+ insert_row('FixedProb',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'prob=0.1',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedProb',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'prob=0.1',
- time_used]
+ insert_row('FixedProb',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'prob=0.1',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedProb',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'prob=0.1',
- time_used]
+ insert_row('FixedProb',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'prob=0.1',
+ time_used)
for size in size_diff:
print('FixedProb:', size)
@@ -214,32 +222,32 @@ def test_fixed_prob(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedProb',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_mat',
- 'prob=0.1',
- time_used]
+ insert_row('FixedProb',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_mat',
+ 'prob=0.1',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedProb',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_coo',
- 'prob=0.1',
- time_used]
+ insert_row('FixedProb',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_coo',
+ 'prob=0.1',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedProb',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_csr',
- 'prob=0.1',
- time_used]
+ insert_row('FixedProb',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_csr',
+ 'prob=0.1',
+ time_used)
def test_fixed_pre_num(self):
print()
@@ -251,32 +259,32 @@ def test_fixed_pre_num(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'pre_num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'pre_num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'pre_num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'pre_num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'pre_num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'pre_num=10',
+ time_used)
for size in size_diff:
print('FixedPreNum:', size)
@@ -286,32 +294,32 @@ def test_fixed_pre_num(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_mat',
- 'pre_num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_mat',
+ 'pre_num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_coo',
- 'pre_num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_coo',
+ 'pre_num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_csr',
- 'pre_num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_csr',
+ 'pre_num=10',
+ time_used)
def test_fixed_post_num(self):
print()
@@ -323,32 +331,32 @@ def test_fixed_post_num(self):
start = time.time()
mat = conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'num=10',
- time_used]
+ insert_row('FixedPreNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'num=10',
- time_used]
+      insert_row('FixedPostNum',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'num=10',
+ time_used)
for size in size_diff:
print('FixedPostNum:', size)
@@ -358,32 +366,32 @@ def test_fixed_post_num(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_mat',
- 'pre_num=10',
- time_used]
+      insert_row('FixedPostNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_mat',
+ 'pre_num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_coo',
- 'pre_num=10',
- time_used]
+      insert_row('FixedPostNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_coo',
+ 'pre_num=10',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['FixedPreNum',
- 'TwoEndConnector',
- f'{size[0]}x{size[1]}',
- 'build_csr',
- 'pre_num=10',
- time_used]
+      insert_row('FixedPostNum',
+ 'TwoEndConnector',
+ f'{size[0]}x{size[1]}',
+ 'build_csr',
+ 'pre_num=10',
+ time_used)
def test_prob_dist(self):
print()
@@ -395,32 +403,32 @@ def test_prob_dist(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ProbDist',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'prob=0.5',
- time_used]
+ insert_row('ProbDist',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'prob=0.5',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ProbDist',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
- time_used]
+ insert_row('ProbDist',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ProbDist',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
- time_used]
+ insert_row('ProbDist',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True',
+ time_used)
def test_small_world(self):
print()
@@ -432,32 +440,32 @@ def test_small_world(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['SmallWorld',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'num_neighbor=2/prob=0.5/include_self=False',
- time_used]
+ insert_row('SmallWorld',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'num_neighbor=2/prob=0.5/include_self=False',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['SmallWorld',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'num_neighbor=2/prob=0.5/include_self=False',
- time_used]
+ insert_row('SmallWorld',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'num_neighbor=2/prob=0.5/include_self=False',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['SmallWorld',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'num_neighbor=2/prob=0.5/include_self=False',
- time_used]
+ insert_row('SmallWorld',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'num_neighbor=2/prob=0.5/include_self=False',
+ time_used)
def test_scale_free_ba(self):
print()
@@ -469,32 +477,32 @@ def test_scale_free_ba(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ScaleFreeBA',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'm=2',
- time_used]
+ insert_row('ScaleFreeBA',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=2',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ScaleFreeBA',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'm=2',
- time_used]
+ insert_row('ScaleFreeBA',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'm=2',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ScaleFreeBA',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'm=2',
- time_used]
+ insert_row('ScaleFreeBA',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'm=2',
+ time_used)
def test_scale_free_ba_dual(self):
print()
@@ -506,32 +514,32 @@ def test_scale_free_ba_dual(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ScaleFreeBADual',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'm1=2/m2=3/p=0.4',
- time_used]
+ insert_row('ScaleFreeBADual',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm1=2/m2=3/p=0.4',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ScaleFreeBADual',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'm1=2/m2=3/p=0.4',
- time_used]
+ insert_row('ScaleFreeBADual',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'm1=2/m2=3/p=0.4',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['ScaleFreeBADual',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'm1=2/m2=3/p=0.4',
- time_used]
+ insert_row('ScaleFreeBADual',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'm1=2/m2=3/p=0.4',
+ time_used)
def test_power_law(self):
print()
@@ -543,32 +551,32 @@ def test_power_law(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['PowerLaw',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- 'm=3/p=0.4',
- time_used]
+ insert_row('PowerLaw',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=3/p=0.4',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['PowerLaw',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- 'm=3/p=0.4',
- time_used]
+ insert_row('PowerLaw',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ 'm=3/p=0.4',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['PowerLaw',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- 'm=3/p=0.4',
- time_used]
+ insert_row('PowerLaw',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ 'm=3/p=0.4',
+ time_used)
def test_one2one(self):
print()
@@ -580,32 +588,32 @@ def test_one2one(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['One2One',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- '',
- time_used]
+ insert_row('One2One',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ '',
+ time_used)
start = time.time()
conn.require(bp.connect.COO)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['One2One',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_coo',
- '',
- time_used]
+ insert_row('One2One',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_coo',
+ '',
+ time_used)
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['One2One',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- '',
- time_used]
+ insert_row('One2One',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ '',
+ time_used)
def test_all2all(self):
print()
@@ -617,12 +625,12 @@ def test_all2all(self):
start = time.time()
conn.require(bp.connect.CONN_MAT)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['All2All',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_mat',
- '',
- time_used]
+ insert_row('All2All',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_mat',
+ '',
+ time_used)
# start = time.time()
# conn.require(bp.connect.COO)
@@ -637,12 +645,12 @@ def test_all2all(self):
start = time.time()
conn.require(bp.connect.CSR)
time_used = get_ms(time.time() - start)
- df.loc[len(df)] = ['All2All',
- 'TwoEndConnector',
- f'{size}x{size}',
- 'build_csr',
- '',
- time_used]
+ insert_row('All2All',
+ 'TwoEndConnector',
+ f'{size}x{size}',
+ 'build_csr',
+ '',
+ time_used)
class TestSave(unittest.TestCase):
From 0df03a59c682f8ff06151c3bc7e967020532be06 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 29 Jun 2023 20:20:23 +0800
Subject: [PATCH 010/326] Update test_all_time.py
---
brainpy/_src/connect/tests/test_all_time.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index 2252e3a82..5d6a7996c 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -655,5 +655,8 @@ def test_all2all(self):
class TestSave(unittest.TestCase):
def test_save(self):
- df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv',
- index=False)
+ try:
+ df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv',
+ index=False)
+ except (NameError, UnboundLocalError):
+ print('No pandas installed, skip test.')
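The guard above works because `df` is only defined when the pandas import succeeded, so `to_csv` raises `NameError` otherwise. A minimal standalone sketch of that optional-dependency pattern (hypothetical names, not the test module itself):

    # Sketch: `df` exists only when pandas imports successfully.
    try:
        import pandas as pd
        df = pd.DataFrame(columns=['name', 'time(ms)'])
    except (ImportError, ModuleNotFoundError):
        print('No pandas installed, skip test.')

    def save_results(path):
        try:
            df.to_csv(path, index=False)  # NameError if pandas never imported
        except (NameError, UnboundLocalError):
            print('No pandas installed, skip test.')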
From ba43e537db671e019b562505f739b78956261e0a Mon Sep 17 00:00:00 2001
From: Routhleck <1310722434@qq.com>
Date: Wed, 5 Jul 2023 20:59:14 +0800
Subject: [PATCH 011/326] Optimized ScaleFreeBA, ScaleFreeBADual, PowerLaw and
ProbDist
Optimized ScaleFreeBA, ScaleFreeBADual, PowerLaw, and ProbDist by preallocating `repeated_nodes` as a NumPy array and using Numba to compile the hot loops.
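The core change is the same in all four connectors: instead of growing a Python list with `extend` and converting it to an array on every draw, a fixed-size array is allocated once and filled by slice assignment with a running length. A minimal standalone sketch of the idea, with a fixed stand-in for the preferential-attachment draw (illustrative only, not the actual connector code):

    import numpy as np

    def grow_with_list(num_node, m):
        # Original style: list grows edge by edge, converted to an array at the end.
        repeated_nodes = []
        for source in range(m, num_node):
            targets = list(range(m))  # stand-in for the random target draw
            repeated_nodes.extend(targets)
            repeated_nodes.extend([source] * m)
        return np.asarray(repeated_nodes)

    def grow_preallocated(num_node, m):
        # Optimized style: one upfront allocation, slice assignment, running size.
        repeated_nodes = np.empty(2 * num_node * m, dtype=int)
        size = 0
        for source in range(m, num_node):
            targets = list(range(m))
            repeated_nodes[size:size + m] = targets
            size += m
            repeated_nodes[size:size + m] = source
            size += m
        return repeated_nodes[:size]

    assert np.array_equal(grow_with_list(1000, 2), grow_preallocated(1000, 2))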
---
brainpy/_src/connect/random_conn.py | 334 ++++++++++++++++++++++++----
1 file changed, 295 insertions(+), 39 deletions(-)
diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py
index b4cb5b21a..e9d2fcfae 100644
--- a/brainpy/_src/connect/random_conn.py
+++ b/brainpy/_src/connect/random_conn.py
@@ -4,6 +4,7 @@
from jax import vmap, jit, numpy as jnp
import numpy as np
+from numba import njit, prange
import brainpy.math as bm
from brainpy.errors import ConnectorError
@@ -683,11 +684,11 @@ def __repr__(self):
f'directed={self.directed}, '
f'seed={self.seed})')
- def build_conn(self):
+ def build_mat(self, isOptimized=True):
assert self.pre_num == self.post_num
# seed
- self.seed = self.rng.randint(1, int(1e7))
+ self.rng = np.random.RandomState(self.seed)
numba_seed(self.seed)
num_node = self.pre_num
@@ -700,7 +701,31 @@ def build_conn(self):
# Target nodes for new edges
targets = list(range(self.m))
# List of existing nodes, with nodes repeated once for each adjacent edge
- repeated_nodes = []
+
+ if not isOptimized:
+ repeated_nodes = []
+ # Start adding the other n-m nodes. The first node is m.
+ source = self.m
+ while source < num_node:
+ # Add edges to m nodes from the source.
+ origins = [source] * self.m
+ conn[origins, targets] = True
+ if not self.directed:
+ conn[targets, origins] = True
+ # Add one node to the list for each new edge just created.
+ repeated_nodes.extend(targets)
+ # And the new node "source" has m edges to add to the list.
+ repeated_nodes.extend([source] * self.m)
+ # Now choose m unique nodes from the existing nodes
+ # Pick uniformly from repeated_nodes (preferential attachment)
+ targets = list(self._connect(np.asarray(repeated_nodes), self.m))
+ source += 1
+ return conn
+
+ # List of existing nodes, with nodes repeated once for each adjacent edge
+ # Preallocate repeated_nodes as a numpy array
+ repeated_nodes = np.empty(2 * num_node * self.m, dtype=int)
+ size_repeated_nodes = 0
# Start adding the other n-m nodes. The first node is m.
source = self.m
while source < num_node:
@@ -710,15 +735,17 @@ def build_conn(self):
if not self.directed:
conn[targets, origins] = True
# Add one node to the list for each new edge just created.
- repeated_nodes.extend(targets)
+ repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets
+ size_repeated_nodes += self.m
# And the new node "source" has m edges to add to the list.
- repeated_nodes.extend([source] * self.m)
+ repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source
+ size_repeated_nodes += self.m
# Now choose m unique nodes from the existing nodes
# Pick uniformly from repeated_nodes (preferential attachment)
- targets = list(self._connect(np.asarray(repeated_nodes), self.m))
+ targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m))
source += 1
- return 'mat', conn
+ return conn
class ScaleFreeBADual(TwoEndConnector):
@@ -773,10 +800,10 @@ def __repr__(self):
return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, '
f'p={self.p}, directed={self.directed}, seed={self.seed})')
- def build_conn(self):
+ def build_mat(self, isOptimized=True):
assert self.pre_num == self.post_num
# seed
- self.seed = self.rng.randint(1, int(1e7))
+ self.rng = np.random.RandomState(self.seed)
numba_seed(self.seed)
num_node = self.pre_num
@@ -791,8 +818,38 @@ def build_conn(self):
# Add max(m1,m2) initial nodes (m0 in barabasi-speak)
conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE)
+
+ if not isOptimized:
+ # List of existing nodes, with nodes repeated once for each adjacent edge
+ repeated_nodes = []
+ # Start adding the remaining nodes.
+ source = max(self.m1, self.m2)
+ # Pick which m to use first time (m1 or m2)
+ m = self.m1 if self.rng.random() < self.p else self.m2
+ # Target nodes for new edges
+ targets = list(range(m))
+ while source < num_node:
+ # Add edges to m nodes from the source.
+ origins = [source] * m
+ conn[origins, targets] = True
+ if not self.directed:
+ conn[targets, origins] = True
+ # Add one node to the list for each new edge just created.
+ repeated_nodes.extend(targets)
+ # And the new node "source" has m edges to add to the list.
+ repeated_nodes.extend([source] * m)
+ # Pick which m to use next time (m1 or m2)
+ m = self.m1 if self.rng.random() < self.p else self.m2
+ # Now choose m unique nodes from the existing nodes
+ # Pick uniformly from repeated_nodes (preferential attachment)
+ targets = list(self._connect(np.asarray(repeated_nodes), m))
+ source += 1
+ return conn
+
# List of existing nodes, with nodes repeated once for each adjacent edge
- repeated_nodes = []
+ # Preallocate repeated_nodes as a numpy array
+ repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int)
+ size_repeated_nodes = 0
# Start adding the remaining nodes.
source = max(self.m1, self.m2)
# Pick which m to use first time (m1 or m2)
@@ -806,17 +863,19 @@ def build_conn(self):
if not self.directed:
conn[targets, origins] = True
# Add one node to the list for each new edge just created.
- repeated_nodes.extend(targets)
+ repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets
+ size_repeated_nodes += m
# And the new node "source" has m edges to add to the list.
- repeated_nodes.extend([source] * m)
+ repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source
+ size_repeated_nodes += m
# Pick which m to use next time (m1 or m2)
m = self.m1 if self.rng.random() < self.p else self.m2
# Now choose m unique nodes from the existing nodes
# Pick uniformly from repeated_nodes (preferential attachment)
- targets = list(self._connect(np.asarray(repeated_nodes), m))
+ targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m))
source += 1
- return 'mat', conn
+ return conn
class PowerLaw(TwoEndConnector):
@@ -886,51 +945,99 @@ def _random_subset(seq, m):
def __repr__(self):
return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})')
- def build_conn(self):
+ def build_mat(self, isOptimized=True):
assert self.pre_num == self.post_num
# seed
- self.seed = self.rng.randint(1, int(1e7))
+ self.rng = np.random.RandomState(self.seed)
numba_seed(self.seed)
num_node = self.pre_num
if self.m < 1 or num_node < self.m:
        raise ConnectorError(f"Must have m>1 and m<n, while m={self.m} and n={num_node}")
size = np.prod(pre_size)
+
for i in range(size):
pre_pos = np.asarray([p[i] for p in pre_ids])
pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim)
From 018fdcf24063266aa416311ab03d834bd6583a7d Mon Sep 17 00:00:00 2001
From: Routhleck <1310722434@qq.com>
Date: Wed, 5 Jul 2023 22:06:07 +0800
Subject: [PATCH 012/326] Test the results after optimization
---
brainpy/_src/connect/random_conn.py | 1 +
.../connect/tests/test_GaussianProb_opt.py | 74 ------
brainpy/_src/connect/tests/test_all_time.py | 6 +-
.../connect/tests/test_optimized_result.py | 237 ++++++++++++++++++
.../_src/connect/tests/test_random_conn.py | 2 +-
5 files changed, 243 insertions(+), 77 deletions(-)
delete mode 100644 brainpy/_src/connect/tests/test_GaussianProb_opt.py
create mode 100644 brainpy/_src/connect/tests/test_optimized_result.py
diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py
index e9d2fcfae..5c66e47c7 100644
--- a/brainpy/_src/connect/random_conn.py
+++ b/brainpy/_src/connect/random_conn.py
@@ -1305,6 +1305,7 @@ def _connect_4d(pre_pos, pre_size, post_size, n_dim):
self._connect_3d_jit = _connect_3d_jit
self._connect_4d_jit = _connect_4d_jit
+
def build_coo(self, isOptimized=True):
if len(self.pre_size) != len(self.post_size):
raise ValueError('The dimensions of shapes of two objects to establish connections should '
diff --git a/brainpy/_src/connect/tests/test_GaussianProb_opt.py b/brainpy/_src/connect/tests/test_GaussianProb_opt.py
deleted file mode 100644
index 53d3fa910..000000000
--- a/brainpy/_src/connect/tests/test_GaussianProb_opt.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import pytest
-
-import unittest
-
-import brainpy as bp
-
-from time import time
-
-
-def test_gaussian_prob1():
- conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=100)
-
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = time() - time0
-
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = time() - time0
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
-
-
-def test_gaussian_prob2():
- conn = bp.connect.GaussianProb(sigma=4, seed=123)(pre_size=(10, 10))
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = time() - time0
-
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = time() - time0
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
-
-
-def test_gaussian_prob3():
- conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)(pre_size=(10, 10))
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = time() - time0
-
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = time() - time0
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
-
-
-def test_gaussian_prob4():
- conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)(pre_size=(10, 10, 10))
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = time() - time0
-
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = time() - time0
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py
index 5d6a7996c..b634d6dbe 100644
--- a/brainpy/_src/connect/tests/test_all_time.py
+++ b/brainpy/_src/connect/tests/test_all_time.py
@@ -15,9 +15,11 @@
print('No pandas installed, skip test.')
# size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
-size_same = [100, 500, 2500, 12500]
-size_diff = [(10, 100), (100, 1000), (1000, 10000)]
+# size_same = [100, 500, 2500, 12500]
+# size_diff = [(10, 100), (100, 1000), (1000, 10000)]
+size_same = [100, 500, 2500]
+size_diff = [(10, 100), (100, 1000)]
def get_ms(value):
return round(value * 1000, 4)
diff --git a/brainpy/_src/connect/tests/test_optimized_result.py b/brainpy/_src/connect/tests/test_optimized_result.py
new file mode 100644
index 000000000..7afd03136
--- /dev/null
+++ b/brainpy/_src/connect/tests/test_optimized_result.py
@@ -0,0 +1,237 @@
+# -*- coding: utf-8 -*-
+from datetime import datetime
+
+import pytest
+
+import unittest
+
+import brainpy as bp
+
+from time import time
+
+try:
+ import pandas as pd
+
+ df = pd.DataFrame(
+ columns=['connector name', 'connect matrix size', 'build function', 'other parameter', 'time origin(ms)',
+ 'time optimized(ms)'])
+except (ImportError, ModuleNotFoundError):
+ print('No pandas installed, skip test.')
+
+# size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
+# size_same = [100, 500, 2500, 12500]
+size_same = [100, 500, 2500]
+
+def get_ms(value):
+ return round(value * 1000, 4)
+
+
+def insert_row(connector_name, connect_matrix_size, build_function, other_parameter, time_origin_used,
+ time_optimized_used):
+ try:
+ df.loc[len(df)] = [connector_name, connect_matrix_size, build_function, other_parameter, time_origin_used, time_optimized_used]
+ except (NameError, UnboundLocalError):
+ print('No pandas installed, skip test.')
+
+
+def test_GaussianProb1():
+ conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=1 / include_self=False',
+ time_origin,
+ time_optimized)
+
+
+def test_GaussianProb2():
+ conn = bp.connect.GaussianProb(sigma=4, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=4',
+ time_origin,
+ time_optimized)
+
+
+def test_GaussianProb3():
+ conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=4 / periodic_boundary=True',
+ time_origin,
+ time_optimized)
+
+
+def test_GaussianProb4():
+ conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=4 / periodic_boundary=True',
+ time_origin,
+ time_optimized)
+
+
+def test_ScaleFreeBA():
+ conn = bp.connect.ScaleFreeBA(m=2, seed=123)
+ for size in size_same:
+ conn(pre_size=size, post_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ insert_row('ScaleFreeBA',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=2',
+ time_origin,
+ time_optimized)
+
+
+def test_ScaleFreeBADual():
+ conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4, seed=123)
+ for size in size_same:
+ conn(pre_size=size, post_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ insert_row('ScaleFreeBADual',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm1=2 / m2=3 / p=0.4',
+ time_origin,
+ time_optimized)
+
+
+def test_PowerLaw():
+ conn = bp.connect.PowerLaw(m=3, p=0.4, seed=123)
+ for size in size_same:
+ conn(pre_size=size, post_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ insert_row('PowerLaw',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=3 / p=0.4',
+ time_origin,
+ time_optimized)
+
+
+def test_ProbDist():
+ conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=123, include_self=True)
+ # for size in [1000, (100, 20), (4, 20, 20), (4, 3, 8, 5)]:
+ for size in [10000]:
+ conn(pre_size=size, post_size=size)
+ pre_ids1, post_ids1 = conn.build_coo(isOptimized=True)
+ time0 = time()
+ pre_ids1, post_ids1 = conn.build_coo(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ pre_ids2, post_ids2 = conn.build_coo(isOptimized=False)
+ time0 = time()
+ pre_ids2, post_ids2 = conn.build_coo(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ # assert (bp.math.array_equal(pre_ids1, pre_ids2) and bp.math.array_equal(post_ids1, post_ids2))
+ print()
+ print(f'time origin: {time_origin}\ntime optimized: {time_optimized}')
+ insert_row('ProbDist',
+             f'{size}x{size}',
+ 'build_coo',
+ 'dist=1 / prob=0.5 / pre_ratio=0.3 / include_self=True',
+ time_origin,
+ time_optimized)
+
+
+def test_save():
+ try:
+ df.to_csv('opt_time_compare' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv',
+ index=False)
+ except (NameError, UnboundLocalError):
+ print('No pandas installed, skip test.')
\ No newline at end of file
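Each test above builds the connection once before timing a second build, so one-off costs (for example Numba compilation on the first call) are not charged to either the optimized or the original path. A minimal standalone sketch of that measurement pattern, assuming only the standard library:

    import time

    def get_ms(value):
        return round(value * 1000, 4)

    def time_after_warmup(build):
        build()                        # warm-up call: pays any JIT/compilation cost once
        start = time.time()
        build()                        # timed call: measures the steady-state build
        return get_ms(time.time() - start)

    # Usage (hypothetical): time_after_warmup(lambda: conn.build_mat(isOptimized=True))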
diff --git a/brainpy/_src/connect/tests/test_random_conn.py b/brainpy/_src/connect/tests/test_random_conn.py
index de45a5ff0..195761548 100644
--- a/brainpy/_src/connect/tests/test_random_conn.py
+++ b/brainpy/_src/connect/tests/test_random_conn.py
@@ -180,7 +180,7 @@ def test_PowerLaw():
print('conn_mat', mat)
-def test_prob_dist():
+def test_ProbDist():
conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True)
for size in [100, (10, 20), (2, 10, 20), (2, 3, 4, 5)]:
conn(pre_size=size, post_size=size)
From 287df02112899a51b55158d88cecb110cc77956d Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Thu, 6 Jul 2023 09:36:30 +0800
Subject: [PATCH 013/326] Fix bug in connector's `require` function
---
brainpy/_src/connect/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/brainpy/_src/connect/base.py b/brainpy/_src/connect/base.py
index 858fc54a7..3a264d313 100644
--- a/brainpy/_src/connect/base.py
+++ b/brainpy/_src/connect/base.py
@@ -425,7 +425,7 @@ def require(self, *structures):
return bm.as_jax(self.build_coo()[0], dtype=IDX_DTYPE)
elif POST_IDS in structures and _has_coo_imp:
return bm.as_jax(self.build_coo()[1], dtype=IDX_DTYPE)
- elif COO in structures and not _has_coo_imp:
+ elif COO in structures and _has_coo_imp:
return bm.as_jax(self.build_coo(), dtype=IDX_DTYPE)
elif len(structures) == 2:
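The one-character fix flips the guard so the COO branch fires exactly when a `build_coo` implementation exists; with `not _has_coo_imp` the branch was unreachable in the case it was written for. A minimal sketch of the corrected dispatch, with names assumed from the hunk above rather than the full `require` logic:

    # Hypothetical sketch of the corrected guard, not the real require() method.
    def require_coo(structures, has_coo_imp, build_coo):
        if 'coo' in structures and has_coo_imp:   # old code tested `not has_coo_imp`
            return build_coo()                    # e.g. a (pre_ids, post_ids) pair
        raise NotImplementedError('no COO implementation for this connector')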
From 05a4a8690624d022f8f3c38c48b5c2a0a549390c Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 8 Jul 2023 22:31:56 +0800
Subject: [PATCH 014/326] rewrite `brainpy.neurons`, `brainpy.synapses` with
new `brainpy.dyn` module
---
brainpy/__init__.py | 209 +-
brainpy/_add_deprecations.py | 102 +
brainpy/_src/_delay.py | 4 +-
.../highdim/tests/test_slow_points.py | 2 +-
brainpy/_src/checkpoints/tests/test_io.py | 4 +-
.../tests/test_random_conn_visualize.py | 2 -
brainpy/_src/context.py | 7 +-
brainpy/_src/delay.py | 48 +-
brainpy/_src/dnn/activations.py | 10 +-
brainpy/_src/dnn/base.py | 4 +-
brainpy/_src/dnn/conv.py | 9 +-
brainpy/_src/dnn/dropout.py | 4 -
brainpy/_src/dnn/interoperation_flax.py | 5 +-
brainpy/_src/dnn/linear.py | 6 +-
brainpy/_src/dyn/base.py | 182 --
brainpy/_src/dyn/channels/Ca.py | 329 +--
brainpy/_src/dyn/channels/IH.py | 22 +-
brainpy/_src/dyn/channels/K.py | 20 +-
brainpy/_src/dyn/channels/KCa.py | 18 +-
brainpy/_src/dyn/channels/Na.py | 6 +-
brainpy/_src/dyn/channels/base.py | 130 +-
brainpy/_src/dyn/channels/leaky.py | 2 +-
brainpy/_src/dyn/channels/tests/test_Ca.py | 51 +-
brainpy/_src/dyn/ions/__init__.py | 3 +
brainpy/_src/dyn/ions/base.py | 96 +
brainpy/_src/dyn/ions/ca.py | 317 +++
brainpy/_src/dyn/neurons/base.py | 53 +
brainpy/_src/dyn/neurons/hh.py | 200 +-
brainpy/_src/dyn/neurons/lif.py | 39 +-
brainpy/_src/dyn/neurons/tests/test_hh.py | 8 +-
brainpy/_src/dyn/others/common.py | 2 +-
brainpy/_src/dyn/{neurons => others}/input.py | 97 +-
.../noise_groups.py => dyn/others/noise.py} | 28 +-
.../{neurons => others}/tests/test_input.py | 2 +-
.../others}/tests/test_input_groups.py | 4 +
.../others}/tests/test_noise_groups.py | 3 +-
brainpy/_src/dyn/outs/__init__.py | 2 +
brainpy/_src/dyn/outs/base.py | 21 +
.../_src/dyn/{synapses => outs}/outputs.py | 17 +-
brainpy/_src/dyn/projections/__init__.py | 3 +
.../{projections.py => projections/aligns.py} | 69 +-
brainpy/_src/dyn/projections/others.py | 73 +
brainpy/_src/{ => dyn}/rates/__init__.py | 0
brainpy/_src/{ => dyn}/rates/populations.py | 7 +-
.../_src/{ => dyn}/rates/tests/test_rates.py | 2 +-
brainpy/_src/dyn/synapses/__init__.py | 3 +
.../{dynamics.py => abstract_models.py} | 393 +--
brainpy/_src/dyn/synapses/bio_models.py | 328 +++
.../{ => dyn}/synapses/delay_couplings.py | 2 +-
.../_src/{ => dyn}/synapses/gap_junction.py | 6 +-
.../synapses}/test_delay_couplings.py | 4 +
.../synapses}/test_gap_junction.py | 2 +
brainpy/_src/dyn/utils.py | 16 +
.../_src/{synapses_v2 => dynold}/__init__.py | 0
brainpy/_src/dynold/experimental/__init__.py | 0
.../experimental}/abstract_synapses.py | 2 +-
.../experimental}/base.py | 8 +-
.../experimental}/others.py | 4 +-
.../experimental}/syn_outs.py | 2 +-
.../experimental}/syn_plasticity.py | 4 +-
brainpy/_src/{ => dynold}/neurons/__init__.py | 2 -
.../{ => dynold}/neurons/biological_models.py | 336 +--
.../{ => dynold}/neurons/fractional_models.py | 9 +-
.../{ => dynold}/neurons/reduced_models.py | 1257 ++-------
.../neurons/tests/test_biological_neurons.py | 65 +-
.../neurons/tests/test_fractional_neurons.py | 8 +-
.../neurons/tests/test_reduced_neurons.py | 9 +-
.../_src/{ => dynold}/synapses/__init__.py | 6 +-
.../{ => dynold}/synapses/abstract_models.py | 531 ++--
brainpy/_src/dynold/synapses/base.py | 562 ++++
.../_src/dynold/synapses/biological_models.py | 414 +++
brainpy/_src/dynold/synapses/compat.py | 257 ++
.../{ => dynold}/synapses/learning_rules.py | 97 +-
.../synapses/tests/test_abstract_synapses.py | 126 +
.../tests/test_biological_synapses.py | 103 +
.../synapses/tests/test_learning_rule.py | 33 +
brainpy/_src/{ => dynold}/synouts/__init__.py | 0
.../_src/{ => dynold}/synouts/conductances.py | 15 +-
brainpy/_src/{ => dynold}/synouts/ions.py | 9 +-
.../_src/{ => dynold}/synplast/__init__.py | 0
.../synplast/short_term_plasticity.py | 32 +-
brainpy/_src/dynsys.py | 1476 +++--------
brainpy/_src/integrators/ode/exponential.py | 4 +-
.../ode/tests/test_ode_method_exp_euler.py | 2 +-
brainpy/_src/math/compat_numpy.py | 1 +
brainpy/_src/math/compat_pytorch.py | 1 -
brainpy/_src/math/delayvars.py | 12 +-
brainpy/_src/math/ndarray.py | 8 +
.../math/object_transform/tests/test_base.py | 8 +-
.../tests/test_circular_reference.py | 2 +-
.../object_transform/tests/test_collector.py | 4 +-
.../tests/test_namechecking.py | 2 +-
.../math/object_transform/tests/test_tools.py | 2 +-
brainpy/_src/math/sharding.py | 13 +-
brainpy/_src/mixin.py | 499 +++-
brainpy/_src/neurons/compat.py | 16 -
brainpy/_src/neurons/input_groups.py | 201 --
brainpy/_src/runners.py | 46 +-
brainpy/_src/synapses/biological_models.py | 587 -----
brainpy/_src/synapses/compat.py | 300 ---
.../synapses/tests/test_abstract_synapses.py | 85 -
.../tests/test_biological_synapses.py | 69 -
.../_src/synapses/tests/test_learning_rule.py | 20 -
brainpy/_src/synplast/long_term_plasticity.py | 1 -
brainpy/_src/tests/test_dynsys.py | 40 +
brainpy/_src/tests/test_mixin.py | 30 +
brainpy/_src/train/__init__.py | 3 +-
brainpy/_src/transform.py | 5 +-
brainpy/_src/typing_copy.py | 2273 +++++++++++++++++
brainpy/channels.py | 57 +-
brainpy/dyn/__init__.py | 2 +
brainpy/dyn/channels.py | 23 +-
brainpy/dyn/ions.py | 12 +
brainpy/dyn/neurons.py | 19 +-
brainpy/dyn/others.py | 19 +-
brainpy/dyn/outs.py | 8 +
brainpy/dyn/projections.py | 10 +-
brainpy/dyn/rates.py | 0
brainpy/dyn/synapses.py | 21 +-
brainpy/errors.py | 6 +
brainpy/experimental.py | 8 +-
brainpy/mixin.py | 12 +-
brainpy/neurons.py | 19 +-
brainpy/rates.py | 11 -
brainpy/synapses.py | 33 +
brainpy/synapses/__init__.py | 5 -
brainpy/synapses/dynamics.py | 25 -
brainpy/synapses/synouts.py | 10 -
brainpy/synapses/synplast.py | 6 -
brainpy/synouts.py | 10 +
brainpy/synplast.py | 6 +
131 files changed, 7085 insertions(+), 5814 deletions(-)
create mode 100644 brainpy/_add_deprecations.py
delete mode 100644 brainpy/_src/dyn/base.py
create mode 100644 brainpy/_src/dyn/ions/__init__.py
create mode 100644 brainpy/_src/dyn/ions/base.py
create mode 100644 brainpy/_src/dyn/ions/ca.py
create mode 100644 brainpy/_src/dyn/neurons/base.py
rename brainpy/_src/dyn/{neurons => others}/input.py (69%)
rename brainpy/_src/{neurons/noise_groups.py => dyn/others/noise.py} (68%)
rename brainpy/_src/dyn/{neurons => others}/tests/test_input.py (94%)
rename brainpy/_src/{neurons => dyn/others}/tests/test_input_groups.py (87%)
rename brainpy/_src/{neurons => dyn/others}/tests/test_noise_groups.py (88%)
create mode 100644 brainpy/_src/dyn/outs/__init__.py
create mode 100644 brainpy/_src/dyn/outs/base.py
rename brainpy/_src/dyn/{synapses => outs}/outputs.py (93%)
create mode 100644 brainpy/_src/dyn/projections/__init__.py
rename brainpy/_src/dyn/{projections.py => projections/aligns.py} (70%)
create mode 100644 brainpy/_src/dyn/projections/others.py
rename brainpy/_src/{ => dyn}/rates/__init__.py (100%)
rename brainpy/_src/{ => dyn}/rates/populations.py (99%)
rename brainpy/_src/{ => dyn}/rates/tests/test_rates.py (98%)
rename brainpy/_src/dyn/synapses/{dynamics.py => abstract_models.py} (61%)
create mode 100644 brainpy/_src/dyn/synapses/bio_models.py
rename brainpy/_src/{ => dyn}/synapses/delay_couplings.py (99%)
rename brainpy/_src/{ => dyn}/synapses/gap_junction.py (94%)
rename brainpy/_src/{synapses/tests => dyn/synapses}/test_delay_couplings.py (93%)
rename brainpy/_src/{synapses/tests => dyn/synapses}/test_gap_junction.py (93%)
create mode 100644 brainpy/_src/dyn/utils.py
rename brainpy/_src/{synapses_v2 => dynold}/__init__.py (100%)
create mode 100644 brainpy/_src/dynold/experimental/__init__.py
rename brainpy/_src/{synapses_v2 => dynold/experimental}/abstract_synapses.py (99%)
rename brainpy/_src/{synapses_v2 => dynold/experimental}/base.py (96%)
rename brainpy/_src/{synapses_v2 => dynold/experimental}/others.py (96%)
rename brainpy/_src/{synapses_v2 => dynold/experimental}/syn_outs.py (97%)
rename brainpy/_src/{synapses_v2 => dynold/experimental}/syn_plasticity.py (98%)
rename brainpy/_src/{ => dynold}/neurons/__init__.py (68%)
rename brainpy/_src/{ => dynold}/neurons/biological_models.py (71%)
rename brainpy/_src/{ => dynold}/neurons/fractional_models.py (98%)
rename brainpy/_src/{ => dynold}/neurons/reduced_models.py (61%)
rename brainpy/_src/{ => dynold}/neurons/tests/test_biological_neurons.py (75%)
rename brainpy/_src/{ => dynold}/neurons/tests/test_fractional_neurons.py (80%)
rename brainpy/_src/{ => dynold}/neurons/tests/test_reduced_neurons.py (92%)
rename brainpy/_src/{ => dynold}/synapses/__init__.py (53%)
rename brainpy/_src/{ => dynold}/synapses/abstract_models.py (65%)
create mode 100644 brainpy/_src/dynold/synapses/base.py
create mode 100644 brainpy/_src/dynold/synapses/biological_models.py
create mode 100644 brainpy/_src/dynold/synapses/compat.py
rename brainpy/_src/{ => dynold}/synapses/learning_rules.py (77%)
create mode 100644 brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
create mode 100644 brainpy/_src/dynold/synapses/tests/test_biological_synapses.py
create mode 100644 brainpy/_src/dynold/synapses/tests/test_learning_rule.py
rename brainpy/_src/{ => dynold}/synouts/__init__.py (100%)
rename brainpy/_src/{ => dynold}/synouts/conductances.py (90%)
rename brainpy/_src/{ => dynold}/synouts/ions.py (94%)
rename brainpy/_src/{ => dynold}/synplast/__init__.py (100%)
rename brainpy/_src/{ => dynold}/synplast/short_term_plasticity.py (88%)
delete mode 100644 brainpy/_src/neurons/compat.py
delete mode 100644 brainpy/_src/neurons/input_groups.py
delete mode 100644 brainpy/_src/synapses/biological_models.py
delete mode 100644 brainpy/_src/synapses/compat.py
delete mode 100644 brainpy/_src/synapses/tests/test_abstract_synapses.py
delete mode 100644 brainpy/_src/synapses/tests/test_biological_synapses.py
delete mode 100644 brainpy/_src/synapses/tests/test_learning_rule.py
delete mode 100644 brainpy/_src/synplast/long_term_plasticity.py
create mode 100644 brainpy/_src/tests/test_dynsys.py
create mode 100644 brainpy/_src/tests/test_mixin.py
create mode 100644 brainpy/_src/typing_copy.py
create mode 100644 brainpy/dyn/ions.py
create mode 100644 brainpy/dyn/outs.py
create mode 100644 brainpy/dyn/rates.py
create mode 100644 brainpy/synapses.py
delete mode 100644 brainpy/synapses/__init__.py
delete mode 100644 brainpy/synapses/dynamics.py
delete mode 100644 brainpy/synapses/synouts.py
delete mode 100644 brainpy/synapses/synplast.py
create mode 100644 brainpy/synouts.py
create mode 100644 brainpy/synplast.py
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index c0344c962..d3c5f4e3e 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -1,27 +1,25 @@
# -*- coding: utf-8 -*-
-__version__ = "2.4.2"
+__version__ = "2.4.3"
# fundamental supporting modules
from brainpy import errors, check, tools
try:
import jaxlib
-
del jaxlib
except ModuleNotFoundError:
raise ModuleNotFoundError(tools.jaxlib_install_info) from None
-# Part 1: Math Foundation #
-# ------------------------- #
+# Part: Math Foundation #
+# ----------------------- #
# math foundation
from brainpy import math
from .math import BrainPyObject
-# Part 2: Toolbox #
-# ----------------- #
-
+# Part: Toolbox #
+# --------------- #
# modules of toolbox
from brainpy import (
connect, # synaptic connection
@@ -33,8 +31,9 @@
encoding, # encoding schema
checkpoints, # checkpoints
check, # error checking
+ mixin, # mixin classes
+ algorithms, # online or offline training algorithms
)
-from . import algorithms # online or offline training algorithms
# convenient alias
conn = connect
@@ -50,188 +49,90 @@
from brainpy._src.integrators.sde.generic import (sdeint as sdeint)
from brainpy._src.integrators.fde.generic import (fdeint as fdeint)
-# Part 3: Models #
-# ---------------- #
-from brainpy import (
- channels, # channel models
- neurons, # neuron groups
- synapses, # synapses
- rates, # rate models
- experimental,
-
- dnn, layers, # deep neural network module
- dyn, # dynamics module
- # delay, # delay module
-)
-
-from brainpy.synapses import (
- synouts, # synaptic output
- synplast, # synaptic plasticity
-)
+# Part: Models #
+# -------------- #
+# base classes
from brainpy._src.dynsys import (
DynamicalSystem as DynamicalSystem,
- Container as Container,
+ DynSysGroup as DynSysGroup, # collectors
Sequential as Sequential,
Network as Network,
- NeuGroup as NeuGroup,
- SynConn as SynConn,
- SynOut as SynOut,
- SynSTP as SynSTP,
- SynLTP as SynLTP,
- TwoEndConn as TwoEndConn,
- CondNeuGroup as CondNeuGroup,
- Channel as Channel
+ Dynamics as Dynamics, # dynamics
+ NeuDyn as NeuDyn,
+ SynDyn as SynDyn,
+ IonChaDyn as IonChaDyn,
+)
+DynamicalSystemNS = DynamicalSystem
+NeuGroup = NeuGroupNS = NeuDyn
+
+# building blocks
+from brainpy import (
+ dnn, layers, # module for dnn layers
+ dyn, # module for modeling dynamics
)
# shared parameters
-from brainpy._src.context import share
+from brainpy._src.context import (share as share)
from brainpy._src.dynsys import not_pass_shared
-# running
+
+# Part: Running #
+# --------------- #
from brainpy._src.runners import (DSRunner as DSRunner)
from brainpy._src.transform import (LoopOverTime as LoopOverTime, )
+from brainpy import (running as running)
-# DynamicalSystem base classes
-from brainpy._src.dynsys import (
- DynamicalSystemNS as DynamicalSystemNS,
- NeuGroupNS as NeuGroupNS,
- TwoEndConnNS as TwoEndConnNS,
-)
-from brainpy._src.synapses_v2.base import (SynOutNS as SynOutNS,
- SynSTPNS as SynSTPNS,
- SynConnNS as SynConnNS, )
-
-# Part 4: Training #
-# ------------------ #
+# Part: Training #
+# ---------------- #
from brainpy._src.train.base import (DSTrainer as DSTrainer, )
from brainpy._src.train.back_propagation import (BPTT as BPTT,
- BPFF as BPFF, )
+ BPFF as BPFF,)
from brainpy._src.train.online import (OnlineTrainer as OnlineTrainer,
ForceTrainer as ForceTrainer, )
from brainpy._src.train.offline import (OfflineTrainer as OfflineTrainer,
RidgeTrainer as RidgeTrainer, )
-# Part 6: Others #
-# ------------------ #
-from brainpy import running, testing, analysis
+# Part: Analysis #
+# ---------------- #
+from brainpy import (analysis as analysis)
+
+
+# Part: Others #
+# ---------------- #
+from brainpy import testing
from brainpy._src.visualization import (visualize as visualize)
-from brainpy._src import base, train
-# Part 7: Deprecations #
-# ---------------------- #
+# Part: Deprecations #
+# -------------------- #
+from brainpy._src import base, train
+from brainpy import (
+ channels, # channel models
+ neurons, # neuron groups
+ synapses, # synapses
+ rates, # rate models
+ experimental,
+ synouts, # synaptic output
+ synplast, # synaptic plasticity
+)
from brainpy._src import modes
from brainpy._src.math.object_transform.base import (Base as Base,
- ArrayCollector,
+ ArrayCollector as ArrayCollector,
Collector as Collector, )
# deprecated
-from brainpy._src import checking
-from brainpy._src.synapses import compat
-from brainpy._src.deprecations import deprecation_getattr2
+from brainpy._add_deprecations import deprecation_getattr2
__deprecations = {
+ 'Container': ('brainpy.Container', 'brainpy.DynSysGroup', DynSysGroup),
'optimizers': ('brainpy.optimizers', 'brainpy.optim', optim),
'TensorCollector': ('brainpy.TensorCollector', 'brainpy.ArrayCollector', ArrayCollector),
}
__getattr__ = deprecation_getattr2('brainpy', __deprecations)
-tools.__deprecations = {
- 'clear_name_cache': ('brainpy.tools.clear_name_cache', 'brainpy.math.clear_name_cache', math.clear_name_cache),
- 'checking': ('brainpy.tools.checking', 'brainpy.checking', checking),
-}
-tools.__getattr__ = deprecation_getattr2('brainpy.tools', tools.__deprecations)
-
-integrators.__deprecations = {
- 'Integrator': ('brainpy.integrators.Integrator', 'brainpy.Integrator', Integrator),
- 'odeint': ('brainpy.integrators.odeint', 'brainpy.odeint', odeint),
- 'sdeint': ('brainpy.integrators.sdeint', 'brainpy.sdeint', sdeint),
- 'fdeint': ('brainpy.integrators.fdeint', 'brainpy.fdeint', fdeint),
- 'IntegratorRunner': ('brainpy.integrators.IntegratorRunner', 'brainpy.IntegratorRunner', IntegratorRunner),
- 'JointEq': ('brainpy.integrators.JointEq', 'brainpy.JointEq', JointEq),
-}
-integrators.__getattr__ = deprecation_getattr2('brainpy.integrators', integrators.__deprecations)
-
-train.__deprecations = {
- 'DSTrainer': ('brainpy.train.DSTrainer', 'brainpy.DSTrainer', DSTrainer),
- 'BPTT': ('brainpy.train.BPTT', 'brainpy.BPTT', BPTT),
- 'BPFF': ('brainpy.train.BPFF', 'brainpy.BPFF', BPFF),
- 'OnlineTrainer': ('brainpy.train.OnlineTrainer', 'brainpy.OnlineTrainer', OnlineTrainer),
- 'ForceTrainer': ('brainpy.train.ForceTrainer', 'brainpy.ForceTrainer', ForceTrainer),
- 'OfflineTrainer': ('brainpy.train.OfflineTrainer', 'brainpy.OfflineTrainer', OfflineTrainer),
- 'RidgeTrainer': ('brainpy.train.RidgeTrainer', 'brainpy.RidgeTrainer', RidgeTrainer),
-}
-train.__getattr__ = deprecation_getattr2('brainpy.train', train.__deprecations)
-
-ode.__deprecations = {'odeint': ('brainpy.ode.odeint', 'brainpy.odeint', odeint)}
-ode.__getattr__ = deprecation_getattr2('brainpy.ode', ode.__deprecations)
-
-sde.__deprecations = {'sdeint': ('brainpy.sde.sdeint', 'brainpy.sdeint', sdeint)}
-sde.__getattr__ = deprecation_getattr2('brainpy.sde', sde.__deprecations)
-
-fde.__deprecations = {'fdeint': ('brainpy.fde.fdeint', 'brainpy.fdeint', fdeint)}
-fde.__getattr__ = deprecation_getattr2('brainpy.fde', sde.__deprecations)
-
-dyn.__deprecations = {
- # module
- # 'channels': ('brainpy.dyn.channels', 'brainpy.channels', channels),
- # 'neurons': ('brainpy.dyn.neurons', 'brainpy.neurons', neurons),
- 'rates': ('brainpy.dyn.rates', 'brainpy.rates', rates),
- # 'synapses': ('brainpy.dyn.synapses', 'brainpy.synapses', synapses),
- 'synouts': ('brainpy.dyn.synouts', 'brainpy.synapses', synouts),
- 'synplast': ('brainpy.dyn.synplast', 'brainpy.synapses', synplast),
-
- # models
- 'DynamicalSystem': ('brainpy.dyn.DynamicalSystem', 'brainpy.DynamicalSystem', DynamicalSystem),
- 'Container': ('brainpy.dyn.Container', 'brainpy.Container', Container),
- 'Sequential': ('brainpy.dyn.Sequential', 'brainpy.Sequential', Sequential),
- 'Network': ('brainpy.dyn.Network', 'brainpy.Network', Network),
- 'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.NeuGroup', NeuGroup),
- 'SynConn': ('brainpy.dyn.SynConn', 'brainpy.SynConn', SynConn),
- # 'SynOut': ('brainpy.dyn.SynOut', 'brainpy.SynOut', SynOut),
- 'SynLTP': ('brainpy.dyn.SynLTP', 'brainpy.SynLTP', SynLTP),
- 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.SynSTP', SynSTP),
- 'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.TwoEndConn', TwoEndConn),
- 'CondNeuGroup': ('brainpy.dyn.CondNeuGroup', 'brainpy.CondNeuGroup', CondNeuGroup),
- 'Channel': ('brainpy.dyn.Channel', 'brainpy.Channel', Channel),
- 'LoopOverTime': ('brainpy.dyn.LoopOverTime', 'brainpy.LoopOverTime', LoopOverTime),
- 'DSRunner': ('brainpy.dyn.DSRunner', 'brainpy.DSRunner', DSRunner),
-
- # neurons
- 'HH': ('brainpy.dyn.HH', 'brainpy.neurons.HH', neurons.HH),
- 'MorrisLecar': ('brainpy.dyn.MorrisLecar', 'brainpy.neurons.MorrisLecar', neurons.MorrisLecar),
- 'PinskyRinzelModel': ('brainpy.dyn.PinskyRinzelModel', 'brainpy.neurons.PinskyRinzelModel',
- neurons.PinskyRinzelModel),
- 'FractionalFHR': ('brainpy.dyn.FractionalFHR', 'brainpy.neurons.FractionalFHR', neurons.FractionalFHR),
- 'FractionalIzhikevich': ('brainpy.dyn.FractionalIzhikevich', 'brainpy.neurons.FractionalIzhikevich',
- neurons.FractionalIzhikevich),
- 'LIF': ('brainpy.dyn.LIF', 'brainpy.neurons.LIF', neurons.LIF),
- 'ExpIF': ('brainpy.dyn.ExpIF', 'brainpy.neurons.ExpIF', neurons.ExpIF),
- 'AdExIF': ('brainpy.dyn.AdExIF', 'brainpy.neurons.AdExIF', neurons.AdExIF),
- 'QuaIF': ('brainpy.dyn.QuaIF', 'brainpy.neurons.QuaIF', neurons.QuaIF),
- 'AdQuaIF': ('brainpy.dyn.AdQuaIF', 'brainpy.neurons.AdQuaIF', neurons.AdQuaIF),
- 'GIF': ('brainpy.dyn.GIF', 'brainpy.neurons.GIF', neurons.GIF),
- 'Izhikevich': ('brainpy.dyn.Izhikevich', 'brainpy.neurons.Izhikevich', neurons.Izhikevich),
- 'HindmarshRose': ('brainpy.dyn.HindmarshRose', 'brainpy.neurons.HindmarshRose', neurons.HindmarshRose),
- 'FHN': ('brainpy.dyn.FHN', 'brainpy.neurons.FHN', neurons.FHN),
- 'SpikeTimeGroup': ('brainpy.dyn.SpikeTimeGroup', 'brainpy.neurons.SpikeTimeGroup', neurons.SpikeTimeGroup),
- 'PoissonGroup': ('brainpy.dyn.PoissonGroup', 'brainpy.neurons.PoissonGroup', neurons.PoissonGroup),
- 'OUProcess': ('brainpy.dyn.OUProcess', 'brainpy.neurons.OUProcess', neurons.OUProcess),
-
- # synapses
- 'DeltaSynapse': ('brainpy.dyn.DeltaSynapse', 'brainpy.synapses.Delta', compat.DeltaSynapse),
- 'ExpCUBA': ('brainpy.dyn.ExpCUBA', 'brainpy.synapses.Exponential', compat.ExpCUBA),
- 'ExpCOBA': ('brainpy.dyn.ExpCOBA', 'brainpy.synapses.Exponential', compat.ExpCOBA),
- 'DualExpCUBA': ('brainpy.dyn.DualExpCUBA', 'brainpy.synapses.DualExponential', compat.DualExpCUBA),
- 'DualExpCOBA': ('brainpy.dyn.DualExpCOBA', 'brainpy.synapses.DualExponential', compat.DualExpCOBA),
- 'AlphaCUBA': ('brainpy.dyn.AlphaCUBA', 'brainpy.synapses.Alpha', compat.AlphaCUBA),
- 'AlphaCOBA': ('brainpy.dyn.AlphaCOBA', 'brainpy.synapses.Alpha', compat.AlphaCOBA),
- # 'NMDA': ('brainpy.dyn.NMDA', 'brainpy.synapses.NMDA', compat.NMDA),
-}
-dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations)
+del deprecation_getattr2
-del deprecation_getattr2, checking, compat
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
new file mode 100644
index 000000000..f2f387cff
--- /dev/null
+++ b/brainpy/_add_deprecations.py
@@ -0,0 +1,102 @@
+
+from ._src import checking, train, integrators
+from . import tools, math, integrators, dyn, neurons, synapses
+from .integrators import ode, fde, sde
+from brainpy._src.integrators.base import Integrator
+from brainpy._src.integrators.runner import IntegratorRunner
+from brainpy._src.integrators.joint_eq import JointEq
+from brainpy._src.integrators.ode.generic import odeint
+from brainpy._src.integrators.sde.generic import sdeint
+from brainpy._src.integrators.fde.generic import fdeint
+from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network,
+ NeuDyn, Projection, IonChaDyn)
+from brainpy._src.runners import DSRunner
+from brainpy._src.deprecations import deprecation_getattr2
+
+tools.__deprecations = {
+ 'clear_name_cache': ('brainpy.tools.clear_name_cache', 'brainpy.math.clear_name_cache', math.clear_name_cache),
+ 'checking': ('brainpy.tools.checking', 'brainpy.checking', checking),
+}
+tools.__getattr__ = deprecation_getattr2('brainpy.tools', tools.__deprecations)
+
+integrators.__deprecations = {
+ 'Integrator': ('brainpy.integrators.Integrator', 'brainpy.Integrator', Integrator),
+ 'odeint': ('brainpy.integrators.odeint', 'brainpy.odeint', odeint),
+ 'sdeint': ('brainpy.integrators.sdeint', 'brainpy.sdeint', sdeint),
+ 'fdeint': ('brainpy.integrators.fdeint', 'brainpy.fdeint', fdeint),
+ 'IntegratorRunner': ('brainpy.integrators.IntegratorRunner', 'brainpy.IntegratorRunner', IntegratorRunner),
+ 'JointEq': ('brainpy.integrators.JointEq', 'brainpy.JointEq', JointEq),
+}
+integrators.__getattr__ = deprecation_getattr2('brainpy.integrators', integrators.__deprecations)
+
+train.__deprecations = {
+ 'DSTrainer': ('brainpy.train.DSTrainer', 'brainpy.DSTrainer', train.base.DSTrainer),
+ 'BPTT': ('brainpy.train.BPTT', 'brainpy.BPTT', train.back_propagation.BPTT),
+ 'BPFF': ('brainpy.train.BPFF', 'brainpy.BPFF', train.back_propagation.BPFF),
+ 'OnlineTrainer': ('brainpy.train.OnlineTrainer', 'brainpy.OnlineTrainer', train.online.OnlineTrainer),
+ 'ForceTrainer': ('brainpy.train.ForceTrainer', 'brainpy.ForceTrainer', train.online.ForceTrainer),
+ 'OfflineTrainer': ('brainpy.train.OfflineTrainer', 'brainpy.OfflineTrainer', train.offline.OfflineTrainer),
+ 'RidgeTrainer': ('brainpy.train.RidgeTrainer', 'brainpy.RidgeTrainer', train.offline.RidgeTrainer),
+}
+train.__getattr__ = deprecation_getattr2('brainpy.train', train.__deprecations)
+
+
+neurons.__deprecations = {
+ 'OUProcess': ('brainpy.neurons.OUProcess', 'brainpy.dyn.OUProcess', dyn.OUProcess),
+ 'Leaky': ('brainpy.neurons.Leaky', 'brainpy.dyn.Leaky', dyn.Leaky),
+ 'Integrator': ('brainpy.neurons.Integrator', 'brainpy.dyn.Integrator', dyn.Integrator),
+ 'InputGroup': ('brainpy.neurons.InputGroup', 'brainpy.dyn.InputGroup', dyn.InputGroup),
+ 'OutputGroup': ('brainpy.neurons.OutputGroup', 'brainpy.dyn.OutputGroup', dyn.OutputGroup),
+ 'SpikeTimeGroup': ('brainpy.neurons.SpikeTimeGroup', 'brainpy.dyn.SpikeTimeGroup', dyn.SpikeTimeGroup),
+ 'PoissonGroup': ('brainpy.neurons.PoissonGroup', 'brainpy.dyn.PoissonGroup', dyn.PoissonGroup),
+}
+neurons.__getattr__ = deprecation_getattr2('brainpy.neurons', neurons.__deprecations)
+
+
+synapses.__deprecations = {
+ 'PoissonInput': ('brainpy.synapses.PoissonInput', 'brainpy.dyn.PoissonInput', dyn.PoissonInput),
+}
+synapses.__getattr__ = deprecation_getattr2('brainpy.synapses', synapses.__deprecations)
+
+
+ode.__deprecations = {
+ 'odeint': ('brainpy.ode.odeint', 'brainpy.odeint', odeint)
+}
+ode.__getattr__ = deprecation_getattr2('brainpy.ode', ode.__deprecations)
+
+sde.__deprecations = {
+ 'sdeint': ('brainpy.sde.sdeint', 'brainpy.sdeint', sdeint)
+}
+sde.__getattr__ = deprecation_getattr2('brainpy.sde', sde.__deprecations)
+
+fde.__deprecations = {
+ 'fdeint': ('brainpy.fde.fdeint', 'brainpy.fdeint', fdeint)
+}
+fde.__getattr__ = deprecation_getattr2('brainpy.fde', fde.__deprecations)
+
+dyn.__deprecations = {
+ # models
+ 'DynamicalSystem': ('brainpy.dyn.DynamicalSystem', 'brainpy.DynamicalSystem', DynamicalSystem),
+ 'Container': ('brainpy.dyn.Container', 'brainpy.DynSysGroup', DynSysGroup),
+ 'Sequential': ('brainpy.dyn.Sequential', 'brainpy.Sequential', Sequential),
+ 'Network': ('brainpy.dyn.Network', 'brainpy.Network', Network),
+ 'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.NeuDyn', NeuDyn),
+ 'Channel': ('brainpy.dyn.Channel', 'brainpy.IonChaDyn', IonChaDyn),
+ 'DSRunner': ('brainpy.dyn.DSRunner', 'brainpy.DSRunner', DSRunner),
+
+ # synapses
+ 'SynConn': ('brainpy.dyn.SynConn', 'brainpy.synapses.SynConn', synapses.SynConn),
+ # 'SynLTP': ('brainpy.dyn.SynLTP', 'brainpy.synapses.SynLTP', synapses.SynLTP),
+ 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses._SynSTP),
+ 'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
+ 'DeltaSynapse': ('brainpy.dyn.DeltaSynapse', 'brainpy.synapses.Delta', synapses.DeltaSynapse),
+ 'ExpCUBA': ('brainpy.dyn.ExpCUBA', 'brainpy.synapses.Exponential', synapses.ExpCUBA),
+ 'ExpCOBA': ('brainpy.dyn.ExpCOBA', 'brainpy.synapses.Exponential', synapses.ExpCOBA),
+ 'DualExpCUBA': ('brainpy.dyn.DualExpCUBA', 'brainpy.synapses.DualExponential', synapses.DualExpCUBA),
+ 'DualExpCOBA': ('brainpy.dyn.DualExpCOBA', 'brainpy.synapses.DualExponential', synapses.DualExpCOBA),
+ 'AlphaCUBA': ('brainpy.dyn.AlphaCUBA', 'brainpy.synapses.Alpha', synapses.AlphaCUBA),
+ 'AlphaCOBA': ('brainpy.dyn.AlphaCOBA', 'brainpy.synapses.Alpha', synapses.AlphaCOBA),
+}
+dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations)
+
+
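For readers unfamiliar with the mechanism used throughout this new module: each `__deprecations` dict maps an old attribute name to a (old path, new path, replacement object) tuple, and `deprecation_getattr2` installs a module-level `__getattr__` (PEP 562) that warns and forwards old names to their replacements. A minimal sketch of what such a shim can look like; the real `deprecation_getattr2` may differ in details:

    import warnings

    def deprecation_getattr2_sketch(module_name, deprecations):
        # Returns a module-level __getattr__: warn on deprecated names and forward
        # to the replacement object; unknown names raise AttributeError as usual.
        def __getattr__(name):
            if name in deprecations:
                old_path, new_path, replacement = deprecations[name]
                warnings.warn(f'{old_path} is deprecated, use {new_path} instead',
                              DeprecationWarning, stacklevel=2)
                return replacement
            raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
        return __getattr__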
diff --git a/brainpy/_src/_delay.py b/brainpy/_src/_delay.py
index b19ad850e..a646fd159 100644
--- a/brainpy/_src/_delay.py
+++ b/brainpy/_src/_delay.py
@@ -11,7 +11,7 @@
from brainpy import check
from brainpy import math as bm
-from brainpy._src.dynsys import DynamicalSystemNS
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
from brainpy._src.context import share
@@ -21,7 +21,7 @@
]
-class Delay(DynamicalSystemNS):
+class Delay(DynamicalSystem):
"""Delay variable which has a fixed delay length.
The data in this delay variable is arranged as::
diff --git a/brainpy/_src/analysis/highdim/tests/test_slow_points.py b/brainpy/_src/analysis/highdim/tests/test_slow_points.py
index 3d3a1d141..f4151cb85 100644
--- a/brainpy/_src/analysis/highdim/tests/test_slow_points.py
+++ b/brainpy/_src/analysis/highdim/tests/test_slow_points.py
@@ -5,7 +5,7 @@
import brainpy.math as bm
-class HH(bp.NeuGroup):
+class HH(bp.NeuDyn):
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
V_th=20., C=1.0, name=None):
super(HH, self).__init__(size=size, name=name)
diff --git a/brainpy/_src/checkpoints/tests/test_io.py b/brainpy/_src/checkpoints/tests/test_io.py
index 5abbe967e..f8ed80210 100644
--- a/brainpy/_src/checkpoints/tests/test_io.py
+++ b/brainpy/_src/checkpoints/tests/test_io.py
@@ -35,7 +35,7 @@ def __init__(self):
io2.a2 = io1.a
io2.b2 = io2.b
- self.net = bp.Container(io1, io2)
+ self.net = bp.DynSysGroup(io1, io2)
print(self.net.vars().keys())
print(self.net.vars().unique().keys())
@@ -115,7 +115,7 @@ def __init__(self):
io1 = IO1()
io2 = IO2()
- self.net = bp.Container(io1, io2)
+ self.net = bp.DynSysGroup(io1, io2)
print(self.net.vars().keys())
print(self.net.vars().unique().keys())
diff --git a/brainpy/_src/connect/tests/test_random_conn_visualize.py b/brainpy/_src/connect/tests/test_random_conn_visualize.py
index a79ca387f..9cd64821c 100644
--- a/brainpy/_src/connect/tests/test_random_conn_visualize.py
+++ b/brainpy/_src/connect/tests/test_random_conn_visualize.py
@@ -2,8 +2,6 @@
import pytest
-import unittest
-
import brainpy as bp
diff --git a/brainpy/_src/context.py b/brainpy/_src/context.py
index 24ace7f80..74d7b6961 100644
--- a/brainpy/_src/context.py
+++ b/brainpy/_src/context.py
@@ -4,10 +4,9 @@
This context defines all shared data used in all modules in a computation.
"""
-from typing import Any
-from typing import Union
+from typing import Any, Union
-from brainpy._src.dynsys import DynamicalSystemNS
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.math.environment import get_dt
from brainpy._src.tools.dicts import DotDict
@@ -16,7 +15,7 @@
]
-class _ShareContext(DynamicalSystemNS):
+class _ShareContext(DynamicalSystem):
def __init__(self):
super().__init__()
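For reference, a short sketch of how later hunks in this patch read the shared
step data (dict-style access on the share singleton, as used below; treating it
as equivalent to share.load(...) is an assumption here):

    from brainpy._src.context import share

    # inside an update() method of a channel or neuron group:
    t = share['t']     # current simulation time
    dt = share['dt']   # integration step size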
diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py
index 2f2681b79..d24248d8c 100644
--- a/brainpy/_src/delay.py
+++ b/brainpy/_src/delay.py
@@ -1,21 +1,20 @@
"""
Delay variable.
"""
+
import math
import numbers
-from typing import Union, Callable, Optional, Dict, Sequence
+from typing import Union, Dict, Callable, Optional
import jax
-from functools import partial
import jax.numpy as jnp
import numpy as np
-from jax.lax import stop_gradient
from brainpy import check
-from brainpy import math as bm, tools
+from brainpy import math as bm
from brainpy._src.context import share
-from brainpy._src.initialize import parameter, variable_
-from brainpy._src.dynsys import DynamicalSystemNS
+from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.initialize import variable_
from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
from brainpy._src.mixin import ParamDesc
from brainpy.check import jit_error
@@ -27,7 +26,7 @@
]
-class Delay(DynamicalSystemNS, ParamDesc):
+class Delay(DynamicalSystem, ParamDesc):
"""Base class for delay variables.
Args:
@@ -61,9 +60,9 @@ def __init__(
# delay method
if method is None:
- if self.mode.is_parent_of(bm.NonBatchingMode):
+ if self.mode.is_one_of(bm.NonBatchingMode, bm.BatchingMode):
method = ROTATE_UPDATE
- elif self.mode.is_parent_of(bm.TrainingMode):
+ elif self.mode.is_a(bm.TrainingMode):
method = CONCAT_UPDATE
else:
method = ROTATE_UPDATE
@@ -129,7 +128,7 @@ def retrieve(self, delay_step, *indices):
raise NotImplementedError()
-class _TargetDelay1(Delay):
+class VariableDelay2(Delay):
"""Delay variable which has a fixed delay length.
The data in this delay variable is arranged as::
@@ -170,7 +169,6 @@ def __init__(
# delay target
target: bm.Variable,
- sharding: Optional[Sequence[str]] = None,
# delay time
time: Optional[Union[int, float]] = None,
@@ -198,22 +196,15 @@ def __init__(
assert target.batch_axis is not None
# sharding
- if sharding is not None:
- if len(sharding) == target.ndim:
- sharding = list(sharding)
- elif len(sharding) + 1 == target.ndim and target.batch_axis is not None:
- sharding = list(sharding)
- sharding.insert(target.batch_axis, bm.sharding.BATCH_AXIS)
- else:
- raise ValueError('sharding axis names do not match the target dimension. ')
- self._target_axis_names = tuple(sharding)
- if sharding is not None:
- sharding = list(sharding)
+ sharding = None
+ if target.axis_names is not None:
+ sharding = list(target.axis_names)
sharding.insert(0, bm.sharding.TIME_AXIS)
- self._data_sharding = tuple(sharding)
+ sharding = tuple(sharding)
+ self.axis_names = sharding
# target
- self.target = bm.sharding.partition(target, self._target_axis_names)
+ self.target = target
# delay data
self._init = init
@@ -353,7 +344,7 @@ def retrieve(self, delay_step, *indices):
if self.method == ROTATE_UPDATE:
i = share.load('i')
delay_idx = (i + delay_step) % (self.max_length + 1)
- delay_idx = stop_gradient(delay_idx)
+ delay_idx = jax.lax.stop_gradient(delay_idx)
elif self.method == CONCAT_UPDATE:
delay_idx = delay_step
@@ -618,7 +609,7 @@ def retrieve(self, delay_step, *indices):
if self.method == ROTATE_UPDATE:
i = share.load('i')
delay_idx = (i + delay_step - 1) % self.max_length
- delay_idx = stop_gradient(delay_idx)
+ delay_idx = jax.lax.stop_gradient(delay_idx)
elif self.method == CONCAT_UPDATE:
delay_idx = delay_step
@@ -654,7 +645,8 @@ def update(
# update the delay data at the first position
elif self.method == CONCAT_UPDATE:
if self.max_length > 1:
- self.data.value = bm.vstack([latest_value, self.data[1:]])
+ latest_value = bm.expand_dims(latest_value, 0)
+ self.data.value = bm.concat([latest_value, self.data[1:]], axis=0)
else:
self.data[0] = latest_value
@@ -742,3 +734,5 @@ def update(
"""
self.target.value = latest_value
super().update(latest_value)
+
+
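For intuition, a toy sketch of the ROTATE_UPDATE read index used in retrieve()
above (plain Python; the buffer length and delay steps are made-up values):

    max_length = 4                 # the delay buffer holds max_length + 1 slots
    buffer_size = max_length + 1
    for i in range(6):             # i plays the role of share.load('i'), the global step
        for delay_step in (0, 2, max_length):
            read_idx = (i + delay_step) % buffer_size
            # read_idx is the slot holding the value delayed by `delay_step` steps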
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index e9f342319..e7461b016 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -4,10 +4,12 @@
from brainpy.types import ArrayType
from .base import Layer
-__all__ = ['Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
- 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
- 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU', 'Softsign', 'Tanhshrink',
- 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax']
+__all__ = [
+ 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
+ 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU',
+ 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU', 'Softsign', 'Tanhshrink',
+ 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax'
+]
def _inplace(inp, val, inplace):
diff --git a/brainpy/_src/dnn/base.py b/brainpy/_src/dnn/base.py
index d82e1c178..af0b4e2fc 100644
--- a/brainpy/_src/dnn/base.py
+++ b/brainpy/_src/dnn/base.py
@@ -1,7 +1,7 @@
-from brainpy._src.dynsys import DynamicalSystemNS
+from brainpy._src.dynsys import DynamicalSystem
-class Layer(DynamicalSystemNS):
+class Layer(DynamicalSystem):
"""Base class for a layer of artificial neural network."""
def reset_state(self, *args, **kwargs):
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index 566949579..4d3fe8366 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -4,7 +4,7 @@
from jax import lax
-from brainpy import math as bm, tools, check
+from brainpy import math as bm, tools
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
from .base import Layer
@@ -81,6 +81,8 @@ class _GeneralConv(Layer):
The name of the object.
"""
+ supported_modes = (bm.TrainingMode, bm.BatchingMode)
+
def __init__(
self,
num_spatial_dims: int,
@@ -99,7 +101,6 @@ def __init__(
name: str = None,
):
super(_GeneralConv, self).__init__(name=name, mode=mode)
- check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode), self.name)
self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
@@ -462,6 +463,8 @@ def _check_input_dim(self, x):
class _GeneralConvTranspose(Layer):
+ supported_modes = (bm.TrainingMode, bm.BatchingMode)
+
def __init__(
self,
num_spatial_dims: int,
@@ -479,8 +482,6 @@ def __init__(
):
super().__init__(name=name, mode=mode)
- assert self.mode.is_parent_of(bm.TrainingMode, bm.BatchingMode)
-
self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
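The hunks above replace per-constructor mode assertions with a declarative
supported_modes class attribute; a hypothetical layer following the same
convention (a sketch, assuming the base class performs the actual check):

    import brainpy.math as bm
    from brainpy._src.dnn.base import Layer

    class MyLayer(Layer):
        # modes this layer supports; enforcement is assumed to live in the base class
        supported_modes = (bm.TrainingMode, bm.BatchingMode)

        def update(self, x):
            return x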
diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py
index ddc2fc7ff..80dbafdd4 100644
--- a/brainpy/_src/dnn/dropout.py
+++ b/brainpy/_src/dnn/dropout.py
@@ -36,10 +36,6 @@ def __init__(
mode: bm.Mode = None,
name: str = None
):
- """
-
-
- """
super(Dropout, self).__init__(mode=mode, name=name)
self.prob = check.is_float(prob, min_bound=0., max_bound=1.)
diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py
index 19d4c757a..b0c9c01ac 100644
--- a/brainpy/_src/dnn/interoperation_flax.py
+++ b/brainpy/_src/dnn/interoperation_flax.py
@@ -5,8 +5,9 @@
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from brainpy import math as bm
-from brainpy._src.dynsys import DynamicalSystemNS, DynamicalSystem
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
+from .base import Layer
try:
import flax # noqa
@@ -34,7 +35,7 @@ def _is_bp(a):
return isinstance(a, bm.Array)
-class FromFlax(DynamicalSystemNS):
+class FromFlax(Layer):
"""
Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`.
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index 39636562a..a5faccc10 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -316,7 +316,7 @@ def update(self, pre_val):
class MaskedLinear(Layer):
- r"""Synaptic matrix multiplication with dense computation.
+ r"""Synaptic matrix multiplication with masked dense computation.
It performs the computation of:
@@ -327,6 +327,10 @@ class MaskedLinear(Layer):
where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value,
:math:`M` the synaptic weight using a dense matrix.
+ >>> import brainpy as bp
+    >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100), weight=0.1)
+
Args:
mask: TwoEndConnector. The connection.
weight: Synaptic weights. Can be a scalar, array, or callable function.
diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
deleted file mode 100644
index 919ca9d39..000000000
--- a/brainpy/_src/dyn/base.py
+++ /dev/null
@@ -1,182 +0,0 @@
-from typing import Sequence, Union, Callable, Any, Optional, Dict
-
-import brainpy.math as bm
-from brainpy._src.dyn._docs import pneu_doc, dpneu_doc
-from brainpy._src.dynsys import NeuGroupNS, DynamicalSystemNS
-from brainpy._src.initialize.generic import parameter, variable_
-from brainpy._src.mixin import ParamDesc, ProjAutoDelay
-from brainpy.check import is_callable
-
-
-__all__ = [
- 'NeuDyn',
- 'SynDyn',
- 'SynOut',
-]
-
-
-class NeuDyn(NeuGroupNS, ProjAutoDelay):
- """Parallelizable Neuron Group.
-
- Args:
- {pneu}
- """
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- sharding: Any = None,
- keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None,
- method: str = 'exp_auto'
- ):
- super().__init__(size=size,
- mode=mode,
- keep_size=keep_size,
- name=name)
-
- # axis names for parallelization
- self.sharding = sharding
-
- # integration method
- self.method = method
-
- # the before- / after-updates used for computing
- self.before_updates: Dict[str, Callable] = bm.node_dict()
- self.after_updates: Dict[str, Callable] = bm.node_dict()
-
- # outputs
- self.cur_inputs: Dict[str, SynOut] = bm.node_dict()
-
- def init_param(self, param, shape=None, sharding=None):
- """Initialize parameters.
-
- If ``sharding`` is provided and ``param`` is array, this function will
- partition the parameter across the default device mesh.
-
- See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting.
- """
- shape = self.varshape if shape is None else shape
- sharding = self.sharding if sharding is None else sharding
- return parameter(param,
- sizes=shape,
- allow_none=False,
- sharding=sharding)
-
- def init_variable(self, var_data, batch_or_mode, shape=None, sharding=None):
- """Initialize variables.
-
- If ``sharding`` is provided and ``var_data`` is array, this function will
- partition the variable across the default device mesh.
-
- See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting.
- """
- shape = self.varshape if shape is None else shape
- sharding = self.sharding if sharding is None else sharding
- return variable_(var_data,
- sizes=shape,
- batch_or_mode=batch_or_mode,
- axis_names=sharding,
- batch_axis_name=bm.sharding.BATCH_AXIS)
-
- def __call__(self, *args, **kwargs):
- # update ``before_updates``
- for model in tuple(self.before_updates.values()):
- model()
-
- # update the model self
- ret = super().__call__(*args, **kwargs)
-
- # update ``after_updates``
- for model in tuple(self.after_updates.values()):
- model(ret)
- return ret
-
-
-NeuDyn.__doc__ = NeuDyn.__doc__.format(pneu=pneu_doc)
-
-
-class GradNeuDyn(NeuDyn):
- """Differentiable and Parallelizable Neuron Group.
-
- Args:
- {pneu}
- {dpneu}
- """
-
- supported_modes = (bm.TrainingMode, bm.NonBatchingMode)
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- sharding: Any = None,
- keep_size: bool = False,
- mode: Optional[bm.Mode] = None,
- name: Optional[str] = None,
- method: str = 'exp_auto',
-
- spk_fun: Callable = bm.surrogate.InvSquareGrad(),
- spk_type: Any = None,
- detach_spk: bool = False,
- ):
- super().__init__(size=size,
- mode=mode,
- keep_size=keep_size,
- name=name,
- sharding=sharding,
- method=method)
-
- self.spk_fun = is_callable(spk_fun)
- self.detach_spk = detach_spk
- self._spk_type = spk_type
-
- @property
- def spk_type(self):
- if self._spk_type is None:
- return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_
- else:
- return self._spk_type
-
-
-GradNeuDyn.__doc__ = GradNeuDyn.__doc__.format(pneu=pneu_doc, dpneu=dpneu_doc)
-
-
-class SynDyn(NeuDyn, ParamDesc):
- """Parallelizable synaptic dynamics.
-
- :py:class:`~.PSynDyn` is a subclass of :py:class:`~.ParamDesc`, because it uses
- the parameter description to describe the uniqueness of the synapse model.
- """
- pass
-
-
-class SynOut(DynamicalSystemNS, ParamDesc):
- def __init__(
- self,
- name: Optional[str] = None,
- ):
- super().__init__(name=name)
- self._conductance = None
-
- def bind_cond(self, conductance):
- self._conductance = conductance
-
- def unbind_cond(self):
- self._conductance = None
-
- def __call__(self, *args, **kwargs):
- if self._conductance is None:
- raise ValueError(f'Please first pack data at the current step using '
- f'".bind_cond(data)". {self}')
- ret = self.update(self._conductance, *args, **kwargs)
- return ret
-
-
-class HHTypeNeuLTC(NeuDyn):
- pass
-
-
-class HHTypeNeu(HHTypeNeuLTC):
- pass
-
diff --git a/brainpy/_src/dyn/channels/Ca.py b/brainpy/_src/dyn/channels/Ca.py
index 9b73c35a2..91c532910 100644
--- a/brainpy/_src/dyn/channels/Ca.py
+++ b/brainpy/_src/dyn/channels/Ca.py
@@ -8,21 +8,15 @@
from typing import Union, Callable
import brainpy.math as bm
-from brainpy._src.dynsys import Channel
-from brainpy._src.initialize import OneInit, Initializer, parameter, variable
+from brainpy._src.context import share
+from brainpy._src.dyn.ions.ca import CalciumDyna
+from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import Shape, ArrayType
-from .base import Calcium, CalciumChannel
+from .base import CalciumChannel
__all__ = [
- 'CalciumFixed',
- 'CalciumDyna',
- 'CalciumDetailed',
- 'CalciumFirstOrder',
-
- '_ICa_p2q_ss', '_ICa_p2q_markov',
-
'ICaN_IS2008',
'ICaT_HM1992',
@@ -34,309 +28,6 @@
]
-class CalciumFixed(Calcium):
- """Fixed Calcium dynamics.
-
- This calcium model has no dynamics. It holds fixed reversal
- potential :math:`E` and concentration :math:`C`.
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- E: Union[float, ArrayType, Initializer, Callable] = 120.,
- C: Union[float, ArrayType, Initializer, Callable] = 2.4e-4,
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- **channels
- ):
- super(CalciumFixed, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- mode=mode,
- **channels)
- self.E = parameter(E, self.varshape, allow_none=False)
- self.C = parameter(C, self.varshape, allow_none=False)
-
- def update(self, tdi, V):
- for node in self.implicit_nodes.values():
- node.update(tdi, V, self.C, self.E)
-
- def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None):
- C_Ca = self.C if C_Ca is None else C_Ca
- E_Ca = self.E if E_Ca is None else E_Ca
- for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values():
- node.reset_state(V, C_Ca, E_Ca, batch_size=batch_size)
-
-
-class CalciumDyna(Calcium):
- """Calcium ion flow with dynamics.
-
- Parameters
- ----------
- size: int, tuple of int
- The ion size.
- keep_size: bool
- Keep the geometry size.
- C0: float, ArrayType, Initializer, Callable
- The Calcium concentration outside of membrane.
- T: float, ArrayType, Initializer, Callable
- The temperature.
- C_initializer: Initializer, Callable, ArrayType
- The initializer for Calcium concentration.
- method: str
- The numerical method.
- name: str
- The ion name.
- """
- R = 8.31441 # gas constant, J*mol-1*K-1
- F = 96.489 # the Faraday constant
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- C0: Union[float, ArrayType, Initializer, Callable] = 2.,
- T: Union[float, ArrayType, Initializer, Callable] = 36.,
- C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- **channels
- ):
- super(CalciumDyna, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- mode=mode,
- **channels)
-
- # parameters
- self.C0 = parameter(C0, self.varshape, allow_none=False)
- self.T = parameter(T, self.varshape, allow_none=False) # temperature
- self._C_initializer = C_initializer
- self._constant = self.R / (2 * self.F) * (273.15 + self.T)
-
- # variables
- self.C = variable(C_initializer, self.mode, self.varshape) # Calcium concentration
- self.E = bm.Variable(self._reversal_potential(self.C),
- batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) # Reversal potential
-
- # function
- self.integral = odeint(self.derivative, method=method)
-
- def derivative(self, C, t, V):
- raise NotImplementedError
-
- def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None):
- self.C.value = variable(self._C_initializer, batch_size, self.varshape) if (C_Ca is None) else C_Ca
- self.E.value = self._reversal_potential(self.C)
- for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values():
- node.reset(V, self.C, self.E, batch_size=batch_size)
-
- def update(self, tdi, V):
- for node in self.nodes(level=1, include_self=False).unique().subset(Channel).values():
- node.update(tdi, V, self.C.value, self.E.value)
- self.C.value = self.integral(self.C.value, tdi['t'], V, tdi['dt'])
- self.E.value = self._reversal_potential(self.C.value)
-
- def _reversal_potential(self, C):
- return self._constant * bm.log(self.C0 / C)
-
-
-class CalciumDetailed(CalciumDyna):
- r"""Dynamical Calcium model proposed.
-
- **1. The dynamics of intracellular** :math:`Ca^{2+}`
-
- The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ :
-
- *(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents*
-
- :math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the
- interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath
- the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed:
-
- .. math::
-
- [Ca]_{i}=-\frac{k}{2 F d} I_{Ca}
-
- where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant,
- :math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane,
- the unit conversion constant is :math:`k=0.1` for :math:`I_T` in
- :math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar,
- and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents.
-
- *(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump*
-
- In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a
- combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium
- efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells.
- Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme:
-
- .. math::
-
- Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+}
-
- where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state,
- :math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration,
- and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}`
- ions have a high affinity for the pump :math:`P`, whereas extrusion of
- :math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore,
- :math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the
- Michaelis-Menten approximation can be used for describing the kinetics of the pump.
- According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is:
-
- .. math::
-
- \frac{[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}}
-
- where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}`
- with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}`
- is the dissociation constant, which can be interpreted here as the value of
- :math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}`
- then the efflux is negligible).
-
- **2.A simple first-order model**
-
- While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is
- described by a simple first-order model,
-
- .. math::
-
- \frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}}
-
- where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d`
- is the thickness of the perimembrane "shell" in which calcium is able to affect
- membrane properties :math:`(1.\, \mathrm{\mu M})`, :math:`z=2` is the valence of the
- :math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is
- the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was
- set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` .
-
- **3. The reversal potential**
-
- The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the
- Nernst equation:
-
- .. math::
-
- E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i}
-
- where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`,
- :math:`T=309.15^{\circ} \mathrm{K}`,
- :math:`F=96,489 \mathrm{C} / \mathrm{mol}`,
- and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`.
-
- Parameters
- ----------
- d : float
- The thickness of the peri-membrane "shell".
- F : float
- The Faraday constant. (:math:`C*mmol^{-1}`)
- tau : float
- The time constant of the :math:`Ca ^{2+}` removal rate. (ms)
- C_rest : float
- The resting :math:`Ca ^{2+}` concentration.
- C0 : float
- The :math:`Ca ^{2+}` concentration outside of the membrane.
- R : float
- The gas constant. (:math:` J*mol^{-1}*K^{-1}`)
-
- References
- ----------
-
- .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski.
- "Ionic mechanisms for intrinsic slow oscillations in thalamic
- relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552.
- .. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J.
- Sejnowski. "Cellular and network models for intrathalamic augmenting
- responses during 10-Hz stimulation." Journal of neurophysiology 79,
- no. 5 (1998): 2730-2748.
-
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- T: Union[float, ArrayType, Initializer, Callable] = 36.,
- d: Union[float, ArrayType, Initializer, Callable] = 1.,
- C_rest: Union[float, ArrayType, Initializer, Callable] = 2.4e-4,
- tau: Union[float, ArrayType, Initializer, Callable] = 5.,
- C0: Union[float, ArrayType, Initializer, Callable] = 2.,
- C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- **channels
- ):
- super(CalciumDetailed, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- T=T,
- C0=C0,
- C_initializer=C_initializer,
- mode=mode,
- **channels)
-
- # parameters
- self.d = parameter(d, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.C_rest = parameter(C_rest, self.varshape, allow_none=False)
-
- def derivative(self, C, t, V):
- ICa = self.current(V, C, self.E)
- drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.)
- return drive + (self.C_rest - C) / self.tau
-
-
-class CalciumFirstOrder(CalciumDyna):
- r"""The first-order calcium concentration model.
-
- .. math::
-
- Ca' = -\alpha I_{Ca} + -\beta Ca
-
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- T: Union[float, ArrayType, Initializer, Callable] = 36.,
- alpha: Union[float, ArrayType, Initializer, Callable] = 0.13,
- beta: Union[float, ArrayType, Initializer, Callable] = 0.075,
- C0: Union[float, ArrayType, Initializer, Callable] = 2.,
- C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- **channels
- ):
- super(CalciumFirstOrder, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- T=T,
- C0=C0,
- C_initializer=C_initializer,
- mode=mode,
- **channels)
-
- # parameters
- self.alpha = parameter(alpha, self.varshape, allow_none=False)
- self.beta = parameter(beta, self.varshape, allow_none=False)
-
- def derivative(self, C, t, V):
- ICa = self.current(V, C, self.E)
- drive = bm.maximum(- self.alpha * ICa, 0.)
- return drive - self.beta * C
-
-
# -------------------------
@@ -407,8 +98,8 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
- def update(self, tdi, V, C_Ca, E_Ca):
- self.p.value, self.q.value = self.integral(self.p, self.q, tdi['t'], V, tdi['dt'])
+ def update(self, V, C_Ca, E_Ca):
+ self.p.value, self.q.value = self.integral(self.p, self.q, share['t'], V, share['dt'])
def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)
@@ -500,8 +191,8 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q)
- def update(self, tdi, V, C_Ca, E_Ca):
- self.p.value, self.q.value = self.integral(self.p, self.q, tdi['t'], V, tdi['dt'])
+ def update(self, V, C_Ca, E_Ca):
+ self.p.value, self.q.value = self.integral(self.p, self.q, share['t'], V, share['dt'])
def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * self.q * (E_Ca - V)
@@ -600,8 +291,8 @@ def derivative(self, p, t, V):
p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6
return self.phi * (phi_p - p) / p_inf
- def update(self, tdi, V, C_Ca, E_Ca):
- self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt'])
+ def update(self, V, C_Ca, E_Ca):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
def current(self, V, C_Ca, E_Ca):
M = C_Ca / (C_Ca + 0.2)
diff --git a/brainpy/_src/dyn/channels/IH.py b/brainpy/_src/dyn/channels/IH.py
index e89763078..708723a3b 100644
--- a/brainpy/_src/dyn/channels/IH.py
+++ b/brainpy/_src/dyn/channels/IH.py
@@ -8,10 +8,12 @@
from typing import Union, Callable
import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.dyn.ions.base import Calcium
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import Shape, ArrayType
-from .base import IhChannel, CalciumChannel, Calcium
+from .base import IhChannel, CalciumChannel
__all__ = [
'Ih_HM1992',
@@ -88,8 +90,8 @@ def reset_state(self, V, batch_size=None):
if batch_size is not None:
assert self.p.shape[0] == batch_size
- def update(self, tdi, V):
- self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt'])
+ def update(self, V):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
def current(self, V):
return self.g_max * self.p * (self.E - V)
@@ -174,12 +176,10 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- # IhChannel.__init__(self, size, name=name, keep_size=keep_size)
- CalciumChannel.__init__(self,
- size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.T = parameter(T, self.varshape, allow_none=False)
@@ -219,9 +219,9 @@ def dOL(self, OL, t, O, P1):
def dP1(self, P1, t, C_Ca):
return self.k1 * C_Ca ** 4 * (1 - P1) - self.k2 * P1
- def update(self, tdi, V, C_Ca, E_Ca):
+ def update(self, V, C_Ca, E_Ca):
self.O.value, self.OL.value, self.P1.value = self.integral(self.O.value, self.OL.value, self.P1.value,
- tdi['t'], V=V, C_Ca=C_Ca, dt=tdi['dt'])
+ share['t'], V=V, C_Ca=C_Ca, dt=share['dt'])
def current(self, V, C_Ca, E_Ca):
return self.g_max * (self.O + self.g_inc * self.OL) * (self.E - V)
diff --git a/brainpy/_src/dyn/channels/K.py b/brainpy/_src/dyn/channels/K.py
index f97ca5b27..93f19a95e 100644
--- a/brainpy/_src/dyn/channels/K.py
+++ b/brainpy/_src/dyn/channels/K.py
@@ -8,6 +8,7 @@
from typing import Union, Callable, Optional
import brainpy.math as bm
+from brainpy._src.context import share
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import Shape, ArrayType
@@ -92,8 +93,8 @@ def __init__(
def derivative(self, p, t, V):
return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p)
- def update(self, tdi, V):
- self.p.value = self.integral(self.p.value, tdi['t'], V, tdi['dt'])
+ def update(self, V):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
def current(self, V):
return self.g_max * self.p ** 4 * (self.E - V)
@@ -415,9 +416,8 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
- def update(self, tdi, V):
- t, dt = tdi['t'], tdi['dt']
- self.p.value, self.q.value = self.integral(self.p.value, self.q.value, t, V, dt)
+ def update(self, V):
+ self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt'])
def current(self, V):
return self.g_max * self.p ** 4 * self.q * (self.E - V)
@@ -710,9 +710,8 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
- def update(self, tdi, V):
- t, dt = tdi['t'], tdi['dt']
- self.p.value, self.q.value = self.integral(self.p.value, self.q.value, t, V, dt)
+ def update(self, V):
+ self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt'])
def current(self, V):
return self.g_max * self.p * self.q * (self.E - V)
@@ -997,9 +996,8 @@ def __init__(
def dp(self, p, t, V):
return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
- def update(self, tdi, V):
- t, dt = tdi['t'], tdi['dt']
- self.p.value = self.integral(self.p.value, t, V, dt)
+ def update(self, V):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
def current(self, V):
return self.g_max * self.p * (self.E - V)
diff --git a/brainpy/_src/dyn/channels/KCa.py b/brainpy/_src/dyn/channels/KCa.py
index 016229d97..28c53e64f 100644
--- a/brainpy/_src/dyn/channels/KCa.py
+++ b/brainpy/_src/dyn/channels/KCa.py
@@ -8,11 +8,13 @@
from typing import Union, Callable
+from brainpy._src.context import share
import brainpy.math as bm
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import Shape, ArrayType
-from .base import Calcium, CalciumChannel, PotassiumChannel
+from .base import CalciumChannel, PotassiumChannel
+from brainpy._src.dyn.ions.base import Calcium
__all__ = [
'IAHP_De1994',
@@ -84,11 +86,10 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- CalciumChannel.__init__(self,
- size=size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size=size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.E = parameter(E, self.varshape, allow_none=False)
@@ -109,9 +110,8 @@ def dp(self, p, t, C_Ca):
C3 = C2 + self.beta
return self.phi * (C2 / C3 - p) * C3
- def update(self, tdi, V, C_Ca, E_Ca):
- t, dt = tdi['t'], tdi['dt']
- self.p.value = self.integral(self.p.value, t, C_Ca=C_Ca, dt=dt)
+ def update(self, V, C_Ca, E_Ca):
+ self.p.value = self.integral(self.p.value, share['t'], C_Ca=C_Ca, dt=share['dt'])
def current(self, V, C_Ca, E_Ca):
return self.g_max * self.p * self.p * (self.E - V)
diff --git a/brainpy/_src/dyn/channels/Na.py b/brainpy/_src/dyn/channels/Na.py
index 533af4057..d29189ae8 100644
--- a/brainpy/_src/dyn/channels/Na.py
+++ b/brainpy/_src/dyn/channels/Na.py
@@ -8,6 +8,7 @@
from typing import Union, Callable
import brainpy.math as bm
+from brainpy._src.context import share
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import ArrayType, Shape
@@ -95,9 +96,8 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q)
- def update(self, tdi, V):
- t, dt = tdi['t'], tdi['dt']
- p, q = self.integral(self.p, self.q, t, V, dt)
+ def update(self, V):
+ p, q = self.integral(self.p, self.q, share['t'], V, share['dt'])
self.p.value, self.q.value = p, q
def current(self, V):
diff --git a/brainpy/_src/dyn/channels/base.py b/brainpy/_src/dyn/channels/base.py
index cb908d7be..db2d9700d 100644
--- a/brainpy/_src/dyn/channels/base.py
+++ b/brainpy/_src/dyn/channels/base.py
@@ -1,51 +1,22 @@
# -*- coding: utf-8 -*-
-from typing import Union
-
-import brainpy.math as bm
-from brainpy._src.dynsys import Container, CondNeuGroup, Channel, check_master
-from brainpy.types import Shape
+from brainpy._src.dynsys import IonChaDyn
+from brainpy._src.mixin import TreeNode
+from brainpy._src.dyn.ions.base import Calcium
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
__all__ = [
- 'Ion', 'IonChannel',
-
- # ions
- 'Calcium',
-
- # ion channels
- 'IhChannel', 'CalciumChannel', 'SodiumChannel', 'PotassiumChannel', 'LeakyChannel',
+ 'IonChannel', 'IhChannel', 'CalciumChannel', 'SodiumChannel', 'PotassiumChannel', 'LeakyChannel',
]
-class Ion(Channel):
- """Base class for ions."""
-
- '''The type of the master object.'''
- master_type = CondNeuGroup
-
- def update(self, tdi, V):
- raise NotImplementedError('Must be implemented by the subclass.')
-
- def reset(self, V, batch_size=None):
- self.reset_state(V, batch_size)
-
- def reset_state(self, V, batch_size=None):
- raise NotImplementedError('Must be implemented by the subclass.')
-
- def current(self, V):
- raise NotImplementedError('Must be implemented by the subclass.')
-
- def __repr__(self):
- return f'{self.__class__.__name__}(size={self.size})'
-
-
-class IonChannel(Channel):
+class IonChannel(IonChaDyn, TreeNode):
"""Base class for ion channels."""
'''The type of the master object.'''
- master_type = CondNeuGroup
+ master_type = HHTypedNeuron
- def update(self, tdi, V):
+ def update(self, V):
raise NotImplementedError('Must be implemented by the subclass.')
def current(self, V):
@@ -57,102 +28,51 @@ def reset(self, V, batch_size=None):
def reset_state(self, V, batch_size=None):
raise NotImplementedError('Must be implemented by the subclass.')
- def __repr__(self):
- return f'{self.__class__.__name__}(size={self.size})'
-
-
-class Calcium(Ion, Container):
- """The brainpy_object calcium dynamics.
+ def clear_input(self):
+ pass
- Parameters
- ----------
- size: int, sequence of int
- The size of the simulation target.
- method: str
- The numerical integration method.
- name: str
- The name of the object.
- **channels
- The calcium dependent channels.
- """
-
- '''The type of the master object.'''
- master_type = CondNeuGroup
-
- """Reversal potential."""
- E: Union[float, bm.Variable, bm.Array]
-
- """Calcium concentration."""
- C: Union[float, bm.Variable, bm.Array]
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- **channels
- ):
- Ion.__init__(self, size, keep_size=keep_size, mode=mode)
- Container.__init__(self, name=name, mode=mode, **channels)
- self.method = method
-
- def current(self, V, C_Ca=None, E_Ca=None):
- C_Ca = self.C if (C_Ca is None) else C_Ca
- E_Ca = self.E if (E_Ca is None) else E_Ca
- nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(Channel).values())
- check_master(type(self), *nodes)
-
- if len(nodes) == 0:
- return 0.
- else:
- current = nodes[0].current(V, C_Ca, E_Ca)
- for node in nodes[1:]:
- current += node.current(V, C_Ca, E_Ca)
- return current
-
- def register_implicit_nodes(self, *channels, **named_channels):
- check_master(type(self), *channels, **named_channels)
- super(Calcium, self).register_implicit_nodes(*channels, **named_channels)
+ def __repr__(self):
+ return f'{self.name}(size={self.size})'
class CalciumChannel(IonChannel):
"""Base class for Calcium ion channels."""
- '''The type of the master object.'''
master_type = Calcium
+ '''The type of the master object.'''
- def update(self, tdi, V, C_Ca, E_Ca):
+ def update(self, V, C_Ca, E_Ca):
raise NotImplementedError
def current(self, V, C_Ca, E_Ca):
raise NotImplementedError
- def reset(self, V, C_Ca, E_Ca, batch_size=None):
+ def reset(self, V, C_Ca, E_Ca, batch_size: int = None):
self.reset_state(V, C_Ca, E_Ca, batch_size)
- def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
+ def reset_state(self, V, C_Ca, E_Ca, batch_size: int = None):
raise NotImplementedError('Must be implemented by the subclass.')
class IhChannel(IonChannel):
"""Base class for Ih channel models."""
- master_type = CondNeuGroup
+ master_type = HHTypedNeuron
class PotassiumChannel(IonChannel):
- """Base class for potassium channel."""
+ """Base class for potassium channel dynamics."""
'''The type of the master object.'''
- master_type = CondNeuGroup
+ master_type = HHTypedNeuron
class LeakyChannel(IonChannel):
- """Base class for leaky channel."""
- master_type = CondNeuGroup
+ """Base class for leaky channel dynamics."""
+
+ master_type = HHTypedNeuron
class SodiumChannel(IonChannel):
- """Base class for sodium channel."""
- master_type = CondNeuGroup
+ """Base class for sodium channel dynamics."""
+
+ master_type = HHTypedNeuron
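A minimal sketch of a custom channel written against the new update(V)
signature (KLeak is a hypothetical model; time and step size would be read from
the share context instead of the removed tdi dict):

    from brainpy._src.dyn.channels.base import PotassiumChannel

    class KLeak(PotassiumChannel):
        def __init__(self, size, g_max=0.1, E=-90.):
            super().__init__(size)
            self.g_max, self.E = g_max, E

        def reset_state(self, V, batch_size=None):
            pass   # stateless channel, nothing to reset

        def update(self, V):
            pass   # a gated channel would integrate here using share['t'] and share['dt']

        def current(self, V):
            return self.g_max * (self.E - V)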
diff --git a/brainpy/_src/dyn/channels/leaky.py b/brainpy/_src/dyn/channels/leaky.py
index 9e3784dd2..5a6f1b5e1 100644
--- a/brainpy/_src/dyn/channels/leaky.py
+++ b/brainpy/_src/dyn/channels/leaky.py
@@ -52,7 +52,7 @@ def __init__(
def reset_state(self, V, batch_size=None):
pass
- def update(self, tdi, V):
+ def update(self, V):
pass
def current(self, V):
diff --git a/brainpy/_src/dyn/channels/tests/test_Ca.py b/brainpy/_src/dyn/channels/tests/test_Ca.py
index 3c08c9873..2ffe1a983 100644
--- a/brainpy/_src/dyn/channels/tests/test_Ca.py
+++ b/brainpy/_src/dyn/channels/tests/test_Ca.py
@@ -6,16 +6,17 @@
from absl.testing import parameterized
from brainpy._src.dyn.channels import Ca
+
class Test_Ca(parameterized.TestCase):
def test_Ca(self):
- bm.random.seed(1234)
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca1 = Ca.CalciumFixed(size)
- self.Ca2 = Ca.CalciumDetailed(size)
- self.Ca3 = Ca.CalciumFirstOrder(size)
+ self.Ca1 = bp.dyn.CalciumFixed(size)
+ self.Ca2 = bp.dyn.CalciumDetailed(size)
+ self.Ca3 = bp.dyn.CalciumFirstOrder(size)
+ bm.random.seed(1234)
model = Neuron(1)
runner = bp.DSRunner(model,
monitors=['V', 'Ca2.C', 'Ca3.C'],
@@ -27,12 +28,13 @@ def __init__(self, size):
def test_ICaN_IS2008(self):
bm.random.seed(1234)
+
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca = Ca.CalciumDetailed(size,
- ICa=Ca.ICaN_IS2008(size),
- )
+ self.Ca = bp.dyn.CalciumDetailed(size,
+ ICa=bp.dyn.ICaN_IS2008(size),
+ )
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -44,12 +46,13 @@ def __init__(self, size):
def test_ICaT_HM1992(self):
bm.random.seed(1234)
+
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca = Ca.CalciumDetailed(size,
- ICa=Ca.ICaT_HM1992(size),
- )
+ self.Ca = bp.dyn.CalciumDetailed(size,
+ ICa=bp.dyn.ICaT_HM1992(size),
+ )
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -63,12 +66,13 @@ def __init__(self, size):
def test_ICaT_HP1992(self):
bm.random.seed(1234)
+
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca = Ca.CalciumDetailed(size,
- ICa=Ca.ICaT_HP1992(size),
- )
+ self.Ca = bp.dyn.CalciumDetailed(size,
+ ICa=bp.dyn.ICaT_HP1992(size),
+ )
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -82,12 +86,13 @@ def __init__(self, size):
def test_ICaHT_HM1992(self):
bm.random.seed(1234)
+
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca = Ca.CalciumDetailed(size,
- ICa=Ca.ICaHT_HM1992(size),
- )
+ self.Ca = bp.dyn.CalciumDetailed(size,
+ ICa=bp.dyn.ICaHT_HM1992(size),
+ )
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -101,12 +106,13 @@ def __init__(self, size):
def test_ICaHT_Re1993(self):
bm.random.seed(1234)
+
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca = Ca.CalciumDetailed(size,
- ICa=Ca.ICaHT_Re1993(size),
- )
+ self.Ca = bp.dyn.CalciumDetailed(size,
+ ICa=bp.dyn.ICaHT_Re1993(size),
+ )
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -120,12 +126,13 @@ def __init__(self, size):
def test_ICaL_IS2008(self):
bm.random.seed(1234)
+
class Neuron(bp.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.Ca = Ca.CalciumDetailed(size,
- ICa=Ca.ICaL_IS2008(size),
- )
+ self.Ca = bp.dyn.CalciumDetailed(size,
+ ICa=bp.dyn.ICaL_IS2008(size),
+ )
model = Neuron(1)
runner = bp.DSRunner(model,
diff --git a/brainpy/_src/dyn/ions/__init__.py b/brainpy/_src/dyn/ions/__init__.py
new file mode 100644
index 000000000..d9d4e9c37
--- /dev/null
+++ b/brainpy/_src/dyn/ions/__init__.py
@@ -0,0 +1,3 @@
+
+from .base import *
+from .ca import *
diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py
new file mode 100644
index 000000000..2b260c03c
--- /dev/null
+++ b/brainpy/_src/dyn/ions/base.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union
+
+import brainpy.math as bm
+from brainpy._src.dyn.neurons.hh import CondNeuGroup
+from brainpy._src.dynsys import IonChaDyn
+from brainpy._src.mixin import Container, TreeNode
+from brainpy.types import Shape
+
+__all__ = [
+ 'Ion',
+ 'Calcium',
+]
+
+
+class Ion(IonChaDyn, TreeNode):
+ """Base class for ions."""
+
+ '''The type of the master object.'''
+ master_type = CondNeuGroup
+
+ def update(self, V):
+ raise NotImplementedError('Must be implemented by the subclass.')
+
+ def reset(self, V, batch_size=None):
+ self.reset_state(V, batch_size)
+
+ def reset_state(self, V, batch_size=None):
+ raise NotImplementedError('Must be implemented by the subclass.')
+
+ def current(self, V):
+ raise NotImplementedError('Must be implemented by the subclass.')
+
+ def clear_input(self):
+ pass
+
+ def __repr__(self):
+ return f'{self.name}(size={self.size})'
+
+
+class Calcium(Ion, Container):
+ """The brainpy_object calcium dynamics.
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The size of the simulation target.
+ method: str
+ The numerical integration method.
+ name: str
+ The name of the object.
+ **channels
+ The calcium dependent channels.
+ """
+
+ '''The type of the master object.'''
+ master_type = CondNeuGroup
+
+ """Reversal potential."""
+ E: Union[float, bm.Variable, bm.Array]
+
+ """Calcium concentration."""
+ C: Union[float, bm.Variable, bm.Array]
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ **channels
+ ):
+ super().__init__(size, keep_size=keep_size, mode=mode, method=method, name=name)
+
+ self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels))
+
+ def update(self, V):
+ for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
+ node.update(V, self.C, self.E)
+
+ def current(self, V, C_Ca=None, E_Ca=None):
+ C_Ca = self.C if (C_Ca is None) else C_Ca
+ E_Ca = self.E if (E_Ca is None) else E_Ca
+ nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
+
+ if len(nodes) == 0:
+ return 0.
+ else:
+ self.check_hierarchies(self.__class__, *nodes)
+ current = nodes[0].current(V, C_Ca, E_Ca)
+ for node in nodes[1:]:
+ current += node.current(V, C_Ca, E_Ca)
+ return current
+
diff --git a/brainpy/_src/dyn/ions/ca.py b/brainpy/_src/dyn/ions/ca.py
new file mode 100644
index 000000000..29a5b8a2e
--- /dev/null
+++ b/brainpy/_src/dyn/ions/ca.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union, Callable
+
+import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.dynsys import IonChaDyn
+from brainpy._src.initialize import OneInit, Initializer, parameter, variable
+from brainpy._src.integrators.ode.generic import odeint
+from brainpy.types import Shape, ArrayType
+from .base import Calcium
+
+__all__ = [
+ 'CalciumFixed',
+ 'CalciumDetailed',
+ 'CalciumFirstOrder',
+]
+
+
+class CalciumFixed(Calcium):
+ """Fixed Calcium dynamics.
+
+ This calcium model has no dynamics. It holds fixed reversal
+ potential :math:`E` and concentration :math:`C`.
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = 120.,
+ C: Union[float, ArrayType, Initializer, Callable] = 2.4e-4,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ **channels
+ ):
+ super(CalciumFixed, self).__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode,
+ **channels)
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.C = parameter(C, self.varshape, allow_none=False)
+
+ def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None):
+ C_Ca = self.C if C_Ca is None else C_Ca
+ E_Ca = self.E if E_Ca is None else E_Ca
+ for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
+ node.reset_state(V, C_Ca, E_Ca, batch_size=batch_size)
+
+
+class CalciumDyna(Calcium):
+ """Calcium ion flow with dynamics.
+
+ Parameters
+ ----------
+ size: int, tuple of int
+ The ion size.
+ keep_size: bool
+ Keep the geometry size.
+ C0: float, ArrayType, Initializer, Callable
+ The Calcium concentration outside of membrane.
+ T: float, ArrayType, Initializer, Callable
+ The temperature.
+ C_initializer: Initializer, Callable, ArrayType
+ The initializer for Calcium concentration.
+ method: str
+ The numerical method.
+ name: str
+ The ion name.
+ """
+ R = 8.31441 # gas constant, J*mol-1*K-1
+ F = 96.489 # the Faraday constant
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ C0: Union[float, ArrayType, Initializer, Callable] = 2.,
+ T: Union[float, ArrayType, Initializer, Callable] = 36.,
+ C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ **channels
+ ):
+ super(CalciumDyna, self).__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode,
+ **channels)
+
+ # parameters
+ self.C0 = parameter(C0, self.varshape, allow_none=False)
+ self.T = parameter(T, self.varshape, allow_none=False) # temperature
+ self._C_initializer = C_initializer
+ self._constant = self.R / (2 * self.F) * (273.15 + self.T)
+
+ # variables
+ self.C = variable(C_initializer, self.mode, self.varshape) # Calcium concentration
+ self.E = bm.Variable(self._reversal_potential(self.C),
+ batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) # Reversal potential
+
+ # function
+ self.integral = odeint(self.derivative, method=method)
+
+ def derivative(self, C, t, V):
+ raise NotImplementedError
+
+ def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None):
+ self.C.value = variable(self._C_initializer, batch_size, self.varshape) if (C_Ca is None) else C_Ca
+ self.E.value = self._reversal_potential(self.C)
+ for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
+ node.reset(V, self.C, self.E, batch_size=batch_size)
+
+ def update(self, V):
+ for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
+ node.update(V, self.C.value, self.E.value)
+ self.C.value = self.integral(self.C.value, share['t'], V, share['dt'])
+ self.E.value = self._reversal_potential(self.C.value)
+
+ def _reversal_potential(self, C):
+ return self._constant * bm.log(self.C0 / C)
+
+
+class CalciumDetailed(CalciumDyna):
+ r"""Dynamical Calcium model proposed.
+
+ **1. The dynamics of intracellular** :math:`Ca^{2+}`
+
+ The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ :
+
+ *(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents*
+
+ :math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the
+ interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath
+ the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed:
+
+ .. math::
+
+ [Ca]_{i}=-\frac{k}{2 F d} I_{Ca}
+
+ where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant,
+ :math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane,
+ the unit conversion constant is :math:`k=0.1` for :math:`I_T` in
+ :math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar,
+ and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents.
+
+ *(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump*
+
+ In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a
+ combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium
+ efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells.
+ Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme:
+
+ .. math::
+
+ Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+}
+
+ where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state,
+ :math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration,
+ and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}`
+ ions have a high affinity for the pump :math:`P`, whereas extrusion of
+  :math:`Ca^{2+}` follows a slower process (Blaustein, 1988). Therefore,
+ :math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the
+ Michaelis-Menten approximation can be used for describing the kinetics of the pump.
+ According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is:
+
+ .. math::
+
+     \frac{d[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}}
+
+ where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}`
+ with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}`
+ is the dissociation constant, which can be interpreted here as the value of
+ :math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}`
+ then the efflux is negligible).
+
+  **2. A simple first-order model**
+
+  In contrast, in (Bazhenov et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is
+ described by a simple first-order model,
+
+ .. math::
+
+ \frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}}
+
+ where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d`
+ is the thickness of the perimembrane "shell" in which calcium is able to affect
+ membrane properties :math:`(1.\, \mathrm{\mu M})`, :math:`z=2` is the valence of the
+ :math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is
+ the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was
+ set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` .
+
+ **3. The reversal potential**
+
+ The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the
+ Nernst equation:
+
+ .. math::
+
+ E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i}
+
+ where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`,
+ :math:`T=309.15^{\circ} \mathrm{K}`,
+ :math:`F=96,489 \mathrm{C} / \mathrm{mol}`,
+ and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`.
+
+ Parameters
+ ----------
+ d : float
+ The thickness of the peri-membrane "shell".
+ F : float
+ The Faraday constant. (:math:`C*mmol^{-1}`)
+ tau : float
+ The time constant of the :math:`Ca ^{2+}` removal rate. (ms)
+ C_rest : float
+ The resting :math:`Ca ^{2+}` concentration.
+ C0 : float
+ The :math:`Ca ^{2+}` concentration outside of the membrane.
+ R : float
+ The gas constant. (:math:` J*mol^{-1}*K^{-1}`)
+
+ References
+ ----------
+
+ .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski.
+ "Ionic mechanisms for intrinsic slow oscillations in thalamic
+ relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552.
+ .. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J.
+ Sejnowski. "Cellular and network models for intrathalamic augmenting
+ responses during 10-Hz stimulation." Journal of neurophysiology 79,
+ no. 5 (1998): 2730-2748.
+
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ T: Union[float, ArrayType, Initializer, Callable] = 36.,
+ d: Union[float, ArrayType, Initializer, Callable] = 1.,
+ C_rest: Union[float, ArrayType, Initializer, Callable] = 2.4e-4,
+ tau: Union[float, ArrayType, Initializer, Callable] = 5.,
+ C0: Union[float, ArrayType, Initializer, Callable] = 2.,
+ C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ **channels
+ ):
+ super(CalciumDetailed, self).__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ T=T,
+ C0=C0,
+ C_initializer=C_initializer,
+ mode=mode,
+ **channels)
+
+ # parameters
+ self.d = parameter(d, self.varshape, allow_none=False)
+ self.tau = parameter(tau, self.varshape, allow_none=False)
+ self.C_rest = parameter(C_rest, self.varshape, allow_none=False)
+
+ def derivative(self, C, t, V):
+ ICa = self.current(V, C, self.E)
+ drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.)
+ return drive + (self.C_rest - C) / self.tau
+
+
+class CalciumFirstOrder(CalciumDyna):
+ r"""The first-order calcium concentration model.
+
+ .. math::
+
+     Ca' = -\alpha I_{Ca} - \beta Ca
+
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ T: Union[float, ArrayType, Initializer, Callable] = 36.,
+ alpha: Union[float, ArrayType, Initializer, Callable] = 0.13,
+ beta: Union[float, ArrayType, Initializer, Callable] = 0.075,
+ C0: Union[float, ArrayType, Initializer, Callable] = 2.,
+ C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ **channels
+ ):
+ super(CalciumFirstOrder, self).__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ T=T,
+ C0=C0,
+ C_initializer=C_initializer,
+ mode=mode,
+ **channels)
+
+ # parameters
+ self.alpha = parameter(alpha, self.varshape, allow_none=False)
+ self.beta = parameter(beta, self.varshape, allow_none=False)
+
+ def derivative(self, C, t, V):
+ ICa = self.current(V, C, self.E)
+ drive = bm.maximum(- self.alpha * ICa, 0.)
+ return drive - self.beta * C
+
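For orientation, a minimal usage sketch of the relocated calcium models,
following the pattern in the updated test_Ca.py above (class and monitor names
are illustrative):

    import brainpy as bp

    class CaNeuron(bp.CondNeuGroup):
        def __init__(self, size):
            super().__init__(size)
            self.Ca = bp.dyn.CalciumDetailed(size, ICa=bp.dyn.ICaT_HM1992(size))

    model = CaNeuron(1)
    runner = bp.DSRunner(model, monitors=['V', 'Ca.C', 'Ca.E'])
    runner.run(100.)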
diff --git a/brainpy/_src/dyn/neurons/base.py b/brainpy/_src/dyn/neurons/base.py
new file mode 100644
index 000000000..bfe75c155
--- /dev/null
+++ b/brainpy/_src/dyn/neurons/base.py
@@ -0,0 +1,53 @@
+from typing import Sequence, Union, Callable, Any, Optional
+
+import brainpy.math as bm
+from brainpy._src.dyn._docs import pneu_doc, dpneu_doc
+from brainpy._src.dynsys import NeuDyn
+from brainpy.check import is_callable
+
+__all__ = ['GradNeuDyn']
+
+
+class GradNeuDyn(NeuDyn):
+ """Differentiable and Parallelizable Neuron Group.
+
+ Args:
+ {pneu}
+ {dpneu}
+ """
+
+ supported_modes = (bm.TrainingMode, bm.NonBatchingMode)
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ sharding: Any = None,
+ keep_size: bool = False,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ method: str = 'exp_auto',
+
+ spk_fun: Callable = bm.surrogate.InvSquareGrad(),
+ spk_type: Any = None,
+ detach_spk: bool = False,
+ ):
+ super().__init__(size=size,
+ mode=mode,
+ keep_size=keep_size,
+ name=name,
+ sharding=sharding,
+ method=method)
+
+ self.spk_fun = is_callable(spk_fun)
+ self.detach_spk = detach_spk
+ self._spk_type = spk_type
+
+ @property
+ def spk_type(self):
+ if self._spk_type is None:
+ return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_
+ else:
+ return self._spk_type
+
+
+GradNeuDyn.__doc__ = GradNeuDyn.__doc__.format(pneu=pneu_doc, dpneu=dpneu_doc)
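
The ``spk_type`` property resolves to a float dtype under ``TrainingMode`` so that surrogate-gradient spikes stay differentiable, and to bool otherwise. Below is a hypothetical subclass sketch of how ``spk_fun`` and ``spk_type`` are typically combined; ``MyNeuron`` and its threshold are illustrative and not part of the patch::

    import brainpy.math as bm
    from brainpy._src.dyn.neurons.base import GradNeuDyn

    class MyNeuron(GradNeuDyn):
        # Hypothetical subclass: a toy membrane with a surrogate-gradient spike.
        def reset_state(self, batch_size=None):
            self.V = bm.Variable(bm.zeros(self.varshape))

        def update(self, x=0.):
            V = self.V + x                    # toy membrane update
            spk = self.spk_fun(V - 1.0)       # differentiable Heaviside around threshold 1.0
            self.V.value = V * (1. - spk)     # soft reset after a spike
            return bm.asarray(spk, dtype=self.spk_type)

    neu = MyNeuron(4, mode=bm.TrainingMode())
    neu.reset_state()
    out = neu.update(bm.ones(4) * 1.5)        # float spikes under TrainingMode
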
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index 2c38e7edb..cbfeb69fa 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -1,26 +1,187 @@
from functools import partial
-from typing import Union, Callable, Optional, Any, Sequence
+from typing import Any, Sequence
+from typing import Union, Callable, Optional
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.initialize import ZeroInit, OneInit, Uniform
-from brainpy._src.integrators import odeint, JointEq
+from brainpy._src.dynsys import NeuDyn, IonChaDyn, DynamicalSystem
+from brainpy._src.initialize import OneInit
+from brainpy._src.initialize import Uniform, variable_, noise as init_noise
+from brainpy._src.integrators import JointEq
+from brainpy._src.integrators import odeint, sdeint
+from brainpy._src.mixin import Container, TreeNode
+from brainpy._src.types import ArrayType
from brainpy.check import is_initializer
-from brainpy.types import Shape, ArrayType, Sharding
-from brainpy._src.dyn.base import HHTypeNeuLTC
-
+from brainpy.types import Shape
__all__ = [
+ 'CondNeuGroupLTC',
+ 'CondNeuGroup',
'HHLTC',
'HH',
'MorrisLecarLTC',
'MorrisLecar',
- 'WangBuzsakiModelLTC',
- 'WangBuzsakiModel'
+ 'WangBuzsakiHHLTC',
+ 'WangBuzsakiHH'
]
-class HHLTC(HHTypeNeuLTC):
+class HHTypedNeuron(NeuDyn, Container, TreeNode):
+ master_type = DynamicalSystem
+
+
+class CondNeuGroupLTC(HHTypedNeuron):
+ r"""Base class to model conductance-based neuron group.
+
+ The standard formulation for a conductance-based model is given as
+
+ .. math::
+
+ C_m {dV \over dt} = \sum_j g_j(E - V) + I_{ext}
+
+ where :math:`g_j=\bar{g}_{j} M^x N^y` is the channel conductance, :math:`E` is the
+ reversal potential, :math:`M` is the activation variable, and :math:`N` is the
+ inactivation variable.
+
+ :math:`M` and :math:`N` have the dynamics of
+
+ .. math::
+
+ {dx \over dt} = \phi_x {x_\infty (V) - x \over \tau_x(V)}
+
+ where :math:`x \in [M, N]`, :math:`\phi_x` is a temperature-dependent factor,
+ :math:`x_\infty` is the steady state, and :math:`\tau_x` is the time constant.
+ Equivalently, the above equation can be written as:
+
+ .. math::
+
+ \frac{d x}{d t}=\phi_{x}\left(\alpha_{x}(1-x)-\beta_{x} x\right)
+
+ where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants.
+
+ .. versionadded:: 2.1.9
+ Support for modeling conductance-based neuron groups.
+
+ Parameters
+ ----------
+ size : int, sequence of int
+ The network size of this neuron group.
+ method: str
+ The numerical integration method.
+ name : optional, str
+ The neuron group name.
+
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ C: Union[float, ArrayType, Callable] = 1.,
+ A: Union[float, ArrayType, Callable] = 1e-3,
+ V_th: Union[float, ArrayType, Callable] = 0.,
+ V_initializer: Union[Callable, ArrayType] = Uniform(-70, -60.),
+ noise: Optional[Union[float, ArrayType, Callable]] = None,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ init_var: bool = True,
+ input_var: bool = True,
+ spk_type: Optional[type] = None,
+ **channels
+ ):
+ super().__init__(size, keep_size=keep_size, mode=mode, name=name, )
+
+ # attribute for ``Container``
+ self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels))
+
+ # parameters for neurons
+ self.input_var = input_var
+ self.C = C
+ self.A = A
+ self.V_th = V_th
+ self.noise = init_noise(noise, self.varshape, num_vars=1)
+ self._V_initializer = V_initializer
+ self.spk_type = ((bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool)
+ if (spk_type is None) else spk_type)
+
+ # function
+ if self.noise is None:
+ self.integral = odeint(f=self.derivative, method=method)
+ else:
+ self.integral = sdeint(f=self.derivative, g=self.noise, method=method)
+
+ if init_var:
+ self.reset_state(self.mode)
+
+ def derivative(self, V, t, I):
+ # synapses
+ for out in self.cur_inputs.values():
+ I = I + out(V)
+ # channels
+ for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values():
+ I = I + ch.current(V)
+ return I / self.C
+
+ def reset_state(self, batch_size=None):
+ self.V = variable_(self._V_initializer, self.varshape, batch_size)
+ self.spike = variable_(partial(bm.zeros, dtype=self.spk_type), self.varshape, batch_size)
+ if self.input_var:
+ self.input = variable_(bm.zeros, self.varshape, batch_size)
+ for channel in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values():
+ channel.reset_state(self.V.value, batch_size=batch_size)
+
+ def update(self, x=None):
+ # inputs
+ x = 0. if x is None else x
+ if self.input_var:
+ self.input += x
+ x = self.input.value
+ x = x * (1e-3 / self.A)
+
+ # integral
+ V = self.integral(self.V.value, share['t'], x, share['dt'])
+
+ # check whether the children channels have the correct parents.
+ channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique()
+ self.check_hierarchies(self.__class__, **channels)
+
+ # update channels
+ for node in channels.values():
+ node.update(self.V.value)
+
+ # update variables
+ if self.spike.dtype == bool:
+ self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
+ else:
+ self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th).astype(self.spike.dtype)
+ self.V.value = V
+ return self.spike
+
+ def clear_input(self):
+ """Useful for monitoring inputs. """
+ if self.input_var:
+ self.input.value = bm.zeros_like(self.input)
+
+ def return_info(self):
+ return self.spike
+
+
+class CondNeuGroup(CondNeuGroupLTC):
+ def derivative(self, V, t, I):
+ for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values():
+ I = I + ch.current(V)
+ return I / self.C
+
+ def update(self, x=None):
+ # inputs
+ x = 0. if x is None else x
+ for out in self.cur_inputs.values():
+ x = x + out(self.V.value)
+ return super().update(x)
+
+
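
The only difference between the LTC ("liquid time constant") class above and ``CondNeuGroup`` is where synaptic inputs enter: the LTC variant folds ``cur_inputs`` into the derivative, so a multi-stage solver re-evaluates them with the instantaneous membrane potential, while the non-LTC variant sums them once per step before integrating. A schematic sketch with plain Python callables, not the BrainPy implementation::

    def euler(f, V, t, I, dt):
        # one explicit-Euler step of dV/dt = f(V, t, I)
        return V + dt * f(V, t, I)

    def ltc_step(V, t, dt, cur_inputs):
        # LTC: synaptic outputs live inside dV/dt, so a multi-stage solver would
        # re-evaluate them with the instantaneous V at every sub-step
        def dV(V, t, I):
            return I + sum(out(V) for out in cur_inputs)
        return euler(dV, V, t, 0., dt)

    def non_ltc_step(V, t, dt, cur_inputs):
        # non-LTC: synaptic outputs are summed once, using last step's V
        I = sum(out(V) for out in cur_inputs)
        return euler(lambda V, t, I: I, V, t, I, dt)

    # toy usage with a single conductance-based input g * (E - V)
    inputs = [lambda V: 0.5 * (0.0 - V)]
    print(ltc_step(-65.0, 0.0, 0.1, inputs), non_ltc_step(-65.0, 0.0, 0.1, inputs))
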
+class HHLTC(NeuDyn):
r"""Hodgkin–Huxley neuron model with liquid time constant.
**Model Descriptions**
@@ -191,6 +352,7 @@ class HHLTC(HHTypeNeuLTC):
frameworks for oscillatory network dynamics in neuroscience."
The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92.
"""
+
def __init__(
self,
size: Union[int, Sequence[int]],
@@ -481,6 +643,7 @@ class HH(HHLTC):
frameworks for oscillatory network dynamics in neuroscience."
The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92.
"""
+
def dV(self, V, t, m, h, n, I):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
I_K = (self.gK * n ** 4.0) * (V - self.EK)
@@ -496,10 +659,10 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
-class MorrisLecarLTC(HHTypeNeuLTC):
+class MorrisLecarLTC(NeuDyn):
r"""The Morris-Lecar neuron model with liquid time constant.
**Model Descriptions**
@@ -572,6 +735,9 @@ class MorrisLecarLTC(HHTypeNeuLTC):
.. [5] http://www.scholarpedia.org/article/Morris-Lecar_model
.. [6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model
"""
+
+ supported_modes = (bm.NonBatchingMode, bm.BatchingMode)
+
def __init__(
self,
size: Union[int, Sequence[int]],
@@ -748,6 +914,7 @@ class MorrisLecar(MorrisLecarLTC):
.. [5] http://www.scholarpedia.org/article/Morris-Lecar_model
.. [6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model
"""
+
def dV(self, V, t, W, I):
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)
@@ -770,10 +937,10 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
-class WangBuzsakiModelLTC(HHTypeNeuLTC):
+class WangBuzsakiHHLTC(NeuDyn):
r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant.
Each model is described by a single compartment and obeys the current balance equation:
@@ -857,6 +1024,7 @@ class WangBuzsakiModelLTC(HHTypeNeuLTC):
neuroscience, 16(20), pp.6402-6413.
"""
+
def __init__(
self,
size: Union[int, Sequence[int]],
@@ -963,7 +1131,8 @@ def update(self, x=None):
def return_info(self):
return self.spike
-class WangBuzsakiModel(WangBuzsakiModelLTC):
+
+class WangBuzsakiHH(WangBuzsakiHHLTC):
r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model.
Each model is described by a single compartment and obeys the current balance equation:
@@ -1047,6 +1216,7 @@ class WangBuzsakiModel(WangBuzsakiModelLTC):
neuroscience, 16(20), pp.6402-6413.
"""
+
def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
beta = 4. * bm.exp(-(V + 60.) / 18.)
@@ -1079,4 +1249,4 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
\ No newline at end of file
+ return super().update(x)
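
A change repeated throughout this patch is ``return super().update(x)`` in place of a bare ``super().update(x)``: the parent computes and returns the spike output, and dropping that return value hands ``None`` to whatever consumes the node's output (for example the projection and delay machinery introduced elsewhere in this patch). A toy illustration with classes that are illustrative, not the BrainPy API::

    class Base:
        def update(self, x):
            spike = x > 1.0          # pretend this is the spike computation
            return spike

    class Broken(Base):
        def update(self, x):
            super().update(x + 0.5)  # result silently dropped -> caller sees None

    class Fixed(Base):
        def update(self, x):
            return super().update(x + 0.5)

    assert Broken().update(1.0) is None
    assert Fixed().update(1.0) is True
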
diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py
index 84e463f93..62dceed37 100644
--- a/brainpy/_src/dyn/neurons/lif.py
+++ b/brainpy/_src/dyn/neurons/lif.py
@@ -10,7 +10,7 @@
from brainpy.check import is_initializer
from brainpy.types import Shape, ArrayType, Sharding
from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc
-from brainpy._src.dyn.base import GradNeuDyn
+from .base import GradNeuDyn
__all__ = [
'IF',
@@ -67,6 +67,7 @@ class IFLTC(GradNeuDyn):
%s
%s
"""
+
def __init__(
self,
size: Shape,
@@ -413,7 +414,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
LifRef.__doc__ = LifRefLTC.__doc__ % ('', lif_doc, pneu_doc, dpneu_doc, ref_doc)
@@ -517,6 +518,7 @@ class ExpIFLTC(GradNeuDyn):
conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919.
.. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire
"""
+
def __init__(
self,
size: Shape,
@@ -616,8 +618,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
-
+ return super().update(x)
class ExpIFRefLTC(ExpIFLTC):
@@ -740,7 +741,8 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
+
ExpIF.__doc__ = ExpIFLTC.__doc__ % ('')
ExpIFRefLTC.__doc__ = ExpIFLTC.__doc__ % (ltc_doc)
@@ -822,6 +824,7 @@ class AdExIFLTC(GradNeuDyn):
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640.
.. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model
"""
+
def __init__(
self,
size: Shape,
@@ -949,7 +952,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
class AdExIFRefLTC(AdExIFLTC):
@@ -1092,13 +1095,15 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
+
AdExIF.__doc__ = AdExIFLTC.__doc__ % ('')
AdExIFRefLTC.__doc__ = AdExIFLTC.__doc__ % (ltc_doc)
AdExIFRef.__doc__ = AdExIFLTC.__doc__ % ('')
AdExIFLTC.__doc__ = AdExIFLTC.__doc__ % (ltc_doc)
+
class QuaIFLTC(GradNeuDyn):
r"""Quadratic Integrate-and-Fire neuron model %s.
@@ -1165,6 +1170,7 @@ class QuaIFLTC(GradNeuDyn):
(2000) Intrinsic dynamics in neuronal networks. I. Theory.
J. Neurophysiology 83, pp. 808–827.
"""
+
def __init__(
self,
size: Shape,
@@ -1262,7 +1268,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
class QuaIFRefLTC(QuaIFLTC):
@@ -1384,7 +1390,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
QuaIF.__doc__ = QuaIFLTC.__doc__ % ('')
@@ -1469,6 +1475,7 @@ class AdQuaIFLTC(GradNeuDyn):
nonlinear integrate-and-fire neurons." SIAM Journal on Applied
Mathematics 68, no. 4 (2008): 1045-1079.
"""
+
def __init__(
self,
size: Shape,
@@ -1592,7 +1599,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
class AdQuaIFRefLTC(AdQuaIFLTC):
@@ -1732,7 +1739,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
AdQuaIF.__doc__ = AdQuaIFLTC.__doc__ % ('')
@@ -1822,6 +1829,7 @@ class GifLTC(GradNeuDyn):
leaky integrate-and-fire models classify multiple neuron types."
Nature communications 9, no. 1 (2018): 1-15.
"""
+
def __init__(
self,
size: Shape,
@@ -1975,7 +1983,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
class GifRefLTC(GifLTC):
@@ -2142,7 +2150,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
Gif.__doc__ = GifLTC.__doc__ % ('')
@@ -2218,6 +2226,7 @@ class IzhikevichLTC(GradNeuDyn):
.. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?."
IEEE transactions on neural networks 15.5 (2004): 1063-1070.
"""
+
def __init__(
self,
size: Shape,
@@ -2339,7 +2348,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
class IzhikevichRefLTC(IzhikevichLTC):
@@ -2475,7 +2484,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x += out(self.V.value)
- super().update(x)
+ return super().update(x)
Izhikevich.__doc__ = IzhikevichLTC.__doc__ % ('')
diff --git a/brainpy/_src/dyn/neurons/tests/test_hh.py b/brainpy/_src/dyn/neurons/tests/test_hh.py
index 2a9bd7a46..c49831579 100644
--- a/brainpy/_src/dyn/neurons/tests/test_hh.py
+++ b/brainpy/_src/dyn/neurons/tests/test_hh.py
@@ -96,7 +96,7 @@ def test_MorrisLecarLTC_batching_mode(self):
self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
def test_WangBuzsakiModel(self):
- model = hh.WangBuzsakiModel(size=1)
+ model = hh.WangBuzsakiHH(size=1)
runner = bp.DSRunner(model,
monitors=['V', 'n', 'h', 'spike'],
progress_bar=False)
@@ -107,7 +107,7 @@ def test_WangBuzsakiModel(self):
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
def test_WangBuzsakiModel_batching_mode(self):
- model = hh.WangBuzsakiModel(size=10, mode=bm.batching_mode)
+ model = hh.WangBuzsakiHH(size=10, mode=bm.batching_mode)
runner = bp.DSRunner(model,
monitors=['V', 'n', 'h', 'spike'],
progress_bar=False)
@@ -118,7 +118,7 @@ def test_WangBuzsakiModel_batching_mode(self):
self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
def test_WangBuzsakiModelLTC(self):
- model = hh.WangBuzsakiModelLTC(size=1)
+ model = hh.WangBuzsakiHHLTC(size=1)
runner = bp.DSRunner(model,
monitors=['V', 'n', 'h', 'spike'],
progress_bar=False)
@@ -129,7 +129,7 @@ def test_WangBuzsakiModelLTC(self):
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
def test_WangBuzsakiModelLTC_batching_mode(self):
- model = hh.WangBuzsakiModelLTC(size=10, mode=bm.batching_mode)
+ model = hh.WangBuzsakiHHLTC(size=10, mode=bm.batching_mode)
runner = bp.DSRunner(model,
monitors=['V', 'n', 'h', 'spike'],
progress_bar=False)
diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py
index ef069d4ea..418cb6ad1 100644
--- a/brainpy/_src/dyn/others/common.py
+++ b/brainpy/_src/dyn/others/common.py
@@ -5,7 +5,7 @@
from brainpy._src import tools
from brainpy._src.context import share
from brainpy._src.dyn._docs import pneu_doc
-from brainpy._src.dyn.base import NeuDyn
+from brainpy._src.dynsys import NeuDyn
from brainpy._src.integrators import odeint
from brainpy.check import is_initializer
from brainpy.types import ArrayType
diff --git a/brainpy/_src/dyn/neurons/input.py b/brainpy/_src/dyn/others/input.py
similarity index 69%
rename from brainpy/_src/dyn/neurons/input.py
rename to brainpy/_src/dyn/others/input.py
index ebe440a33..041f8b59f 100644
--- a/brainpy/_src/dyn/neurons/input.py
+++ b/brainpy/_src/dyn/others/input.py
@@ -1,14 +1,18 @@
# -*- coding: utf-8 -*-
-from typing import Union, Sequence, Any
+from functools import partial
+from typing import Union, Sequence, Any, Optional, Callable
+import jax
import jax.numpy as jnp
+
+from brainpy import math as bm
from brainpy._src.context import share
-import brainpy.math as bm
-from brainpy._src.initialize import Initializer, parameter, variable_
+from brainpy._src.dyn.utils import get_spk_type
+from brainpy._src.dynsys import NeuDyn
+from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import ReturnInfo
from brainpy.types import Shape, ArrayType
-from brainpy._src.dyn.base import NeuDyn
__all__ = [
'InputGroup',
@@ -21,12 +25,11 @@
class InputGroup(NeuDyn):
"""Input neuron group for place holder.
- Parameters
- ----------
- size: int, tuple of int
- keep_size: bool
- mode: Mode
- name: str
+ Args:
+ size: int, tuple of int
+ keep_size: bool
+ mode: Mode
+ name: str
"""
def __init__(
@@ -34,8 +37,8 @@ def __init__(
size: Union[int, Sequence[int]],
sharding: Any = None,
keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
):
super(InputGroup, self).__init__(name=name,
sharding=sharding,
@@ -56,12 +59,11 @@ def reset_state(self, batch_size=None):
class OutputGroup(NeuDyn):
"""Output neuron group for place holder.
- Parameters
- ----------
- size: int, tuple of int
- keep_size: bool
- mode: Mode
- name: str
+ Args:
+ size: int, tuple of int
+ keep_size: bool
+ mode: Mode
+ name: str
"""
def __init__(
@@ -69,24 +71,25 @@ def __init__(
size: Union[int, Sequence[int]],
sharding: Any = None,
keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
):
super(OutputGroup, self).__init__(name=name,
sharding=sharding,
size=size,
keep_size=keep_size,
mode=mode)
- self.spike = None
def update(self, x):
- return bm.sharding.partition(x, sharding=self.sharding)
+ return x
+
+ def return_info(self):
+ return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)
def reset_state(self, batch_size=None):
pass
-
class SpikeTimeGroup(NeuDyn):
"""The input neuron group characterized by spikes emitting at given times.
@@ -120,10 +123,11 @@ def __init__(
size: Union[int, Sequence[int]],
indices: Union[Sequence, ArrayType],
times: Union[Sequence, ArrayType],
- name: str = None,
- sharding: Any = None,
+ spk_type: Optional[type] = None,
+ name: Optional[str] = None,
+ sharding: Optional[Sequence[str]] = None,
keep_size: bool = False,
- mode: bm.Mode = None,
+ mode: Optional[bm.Mode] = None,
need_sort: bool = True,
):
super(SpikeTimeGroup, self).__init__(size=size,
@@ -139,6 +143,7 @@ def __init__(
raise ValueError(f'The length of "indices" and "times" must be the same. '
f'However, we got {len(indices)} != {len(times)}.')
self.num_times = len(times)
+ self.spk_type = get_spk_type(spk_type, self.mode)
# data about times and indices
self.times = bm.asarray(times)
@@ -153,22 +158,26 @@ def __init__(
def reset_state(self, batch_size=None):
self.i = bm.Variable(bm.asarray(0))
- self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
+ self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type),
+ self.varshape,
+ batch_size,
+ axis_names=self.sharding,
+ batch_axis_name=bm.sharding.BATCH_AXIS)
def update(self):
- self.spike.value = bm.zeros_like(self.spike)
- bm.while_loop(self._body_fun, self._cond_fun, share.load('t'))
+ self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding)
+ bm.while_loop(self._body_fun, self._cond_fun, ())
return self.spike.value
def return_info(self):
return self.spike
# functions
- def _cond_fun(self, t):
+ def _cond_fun(self):
i = self.i.value
- return bm.logical_and(i < self.num_times, t >= self.times[i])
+ return bm.logical_and(i < self.num_times, share['t'] >= self.times[i])
- def _body_fun(self, t):
+ def _body_fun(self):
i = self.i.value
if isinstance(self.mode, bm.BatchingMode):
self.spike[:, self.indices[i]] = True
@@ -184,12 +193,12 @@ class PoissonGroup(NeuDyn):
def __init__(
self,
size: Shape,
- freqs: Union[int, float, jnp.ndarray, bm.Array, Initializer],
- seed: int = None,
- name: str = None,
- sharding: Any = None,
+ freqs: Union[int, float, jax.Array, bm.Array, Callable],
keep_size: bool = False,
- mode: bm.Mode = None,
+ sharding: Optional[Sequence[str]] = None,
+ spk_type: Optional[type] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
super(PoissonGroup, self).__init__(size=size,
sharding=sharding,
@@ -198,15 +207,16 @@ def __init__(
mode=mode)
# parameters
- self.keep_size = keep_size
- self.seed = seed
self.freqs = parameter(freqs, self.num, allow_none=False)
+ self.spk_type = get_spk_type(spk_type, self.mode)
# variables
self.reset_state(self.mode)
def update(self):
spikes = bm.random.rand_like(self.spike) <= (self.freqs * share.dt / 1000.)
+ spikes = bm.asarray(spikes, dtype=self.spk_type)
+ spikes = bm.sharding.partition(spikes, self.spike.sharding)
self.spike.value = spikes
return spikes
@@ -214,7 +224,8 @@ def return_info(self):
return self.spike
def reset_state(self, batch_size=None):
- self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
-
-
-
+ self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type),
+ self.varshape,
+ batch_size,
+ axis_names=self.sharding,
+ batch_axis_name=bm.sharding.BATCH_AXIS)
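
``PoissonGroup`` above emits, for every neuron and every step, a spike with probability ``freqs * dt / 1000``, which is the usual Bernoulli approximation of a Poisson process for small per-step probabilities. A standalone NumPy check of that rule, with illustrative values::

    import numpy as np

    freqs = 20.0              # Hz
    dt = 0.1                  # ms
    p = freqs * dt / 1000.    # spike probability per step and per neuron

    rng = np.random.default_rng(0)
    spikes = rng.random((10000, 100)) <= p     # (time steps, neurons)

    # the empirical rate should be close to 20 Hz
    rate_hz = spikes.mean() * 1000. / dt
    print(rate_hz)
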
diff --git a/brainpy/_src/neurons/noise_groups.py b/brainpy/_src/dyn/others/noise.py
similarity index 68%
rename from brainpy/_src/neurons/noise_groups.py
rename to brainpy/_src/dyn/others/noise.py
index 41f09e1ce..255d3f1f1 100644
--- a/brainpy/_src/neurons/noise_groups.py
+++ b/brainpy/_src/dyn/others/noise.py
@@ -1,21 +1,20 @@
-# -*- coding: utf-8 -*-
-
from typing import Union, Callable
import jax.numpy as jnp
+
+import brainpy.math as bm
from brainpy._src.context import share
-from brainpy import math as bm, initialize as init
-from brainpy._src.dynsys import NeuGroupNS
-from brainpy._src.initialize import Initializer
+from brainpy._src.dynsys import NeuDyn
+from brainpy._src.initialize import variable_, parameter
from brainpy._src.integrators.sde.generic import sdeint
-from brainpy.types import ArrayType, Shape
+from brainpy.types import Shape, ArrayType
__all__ = [
'OUProcess',
]
-class OUProcess(NeuGroupNS):
+class OUProcess(NeuDyn):
r"""The Ornstein–Uhlenbeck process.
The Ornstein–Uhlenbeck process :math:`x_{t}` is defined by the following
@@ -47,9 +46,9 @@ class OUProcess(NeuGroupNS):
def __init__(
self,
size: Shape,
- mean: Union[float, ArrayType, Initializer, Callable] = 0.,
- sigma: Union[float, ArrayType, Initializer, Callable] = 1.,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
+ mean: Union[float, ArrayType, Callable] = 0.,
+ sigma: Union[float, ArrayType, Callable] = 1.,
+ tau: Union[float, ArrayType, Callable] = 10.,
method: str = 'exp_euler',
keep_size: bool = False,
mode: bm.Mode = None,
@@ -58,9 +57,9 @@ def __init__(
super(OUProcess, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode)
# parameters
- self.mean = init.parameter(mean, self.varshape, allow_none=False)
- self.sigma = init.parameter(sigma, self.varshape, allow_none=False)
- self.tau = init.parameter(tau, self.varshape, allow_none=False)
+ self.mean = parameter(mean, self.varshape, allow_none=False)
+ self.sigma = parameter(sigma, self.varshape, allow_none=False)
+ self.tau = parameter(tau, self.varshape, allow_none=False)
# variables
self.reset_state(self.mode)
@@ -69,7 +68,7 @@ def __init__(
self.integral = sdeint(f=self.df, g=self.dg, method=method)
def reset_state(self, batch_size=None):
- self.x = init.variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_size)
+ self.x = variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_size)
def df(self, x, t):
return (self.mean - x) / self.tau
@@ -82,4 +81,3 @@ def update(self):
dt = share.load('dt')
self.x.value = self.integral(self.x, t, dt)
return self.x.value
-
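
``OUProcess`` pairs the drift ``df = (mean - x) / tau`` shown above with a diffusion term integrated by ``sdeint``. The sketch below does the same with Euler-Maruyama in plain NumPy, assuming a constant diffusion ``sigma`` (the ``dg`` term is outside the visible hunk, so treat that as an assumption)::

    import numpy as np

    mean, sigma, tau = 0.0, 1.0, 10.0     # the OUProcess defaults
    dt, n_steps = 0.1, 50000

    rng = np.random.default_rng(42)
    x = mean
    trace = np.empty(n_steps)
    for i in range(n_steps):
        x += dt * (mean - x) / tau + sigma * np.sqrt(dt) * rng.standard_normal()
        trace[i] = x

    # the stationary std of this OU process is sigma * sqrt(tau / 2)
    print(trace.std(), sigma * np.sqrt(tau / 2))
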
diff --git a/brainpy/_src/dyn/neurons/tests/test_input.py b/brainpy/_src/dyn/others/tests/test_input.py
similarity index 94%
rename from brainpy/_src/dyn/neurons/tests/test_input.py
rename to brainpy/_src/dyn/others/tests/test_input.py
index fc05c62b8..c1630c38d 100644
--- a/brainpy/_src/dyn/neurons/tests/test_input.py
+++ b/brainpy/_src/dyn/others/tests/test_input.py
@@ -3,7 +3,7 @@
import brainpy as bp
from absl.testing import parameterized
-from brainpy._src.dyn.neurons import input
+from brainpy._src.dyn.others import input
class Test_input(parameterized.TestCase):
diff --git a/brainpy/_src/neurons/tests/test_input_groups.py b/brainpy/_src/dyn/others/tests/test_input_groups.py
similarity index 87%
rename from brainpy/_src/neurons/tests/test_input_groups.py
rename to brainpy/_src/dyn/others/tests/test_input_groups.py
index 17ae99168..1028bcc8e 100644
--- a/brainpy/_src/neurons/tests/test_input_groups.py
+++ b/brainpy/_src/dyn/others/tests/test_input_groups.py
@@ -8,17 +8,21 @@
class Test_input_Group(parameterized.TestCase):
def test_SpikeTimeGroup(self):
+ bp.math.random.seed()
model = input_groups.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
runner = bp.DSRunner(model,
monitors=['spike'],
progress_bar=False)
runner.run(30.)
self.assertTupleEqual(runner.mon['spike'].shape, (300, 2))
+ bp.math.clear_buffer_memory()
def test_PoissonGroup(self):
+ bp.math.random.seed()
model = input_groups.PoissonGroup(size=2, freqs=1000, seed=0)
runner = bp.DSRunner(model,
monitors=['spike'],
progress_bar=False)
runner.run(30.)
self.assertTupleEqual(runner.mon['spike'].shape, (300, 2))
+ bp.math.clear_buffer_memory()
diff --git a/brainpy/_src/neurons/tests/test_noise_groups.py b/brainpy/_src/dyn/others/tests/test_noise_groups.py
similarity index 88%
rename from brainpy/_src/neurons/tests/test_noise_groups.py
rename to brainpy/_src/dyn/others/tests/test_noise_groups.py
index 8ebb3ed7e..2fc831e61 100644
--- a/brainpy/_src/neurons/tests/test_noise_groups.py
+++ b/brainpy/_src/dyn/others/tests/test_noise_groups.py
@@ -18,4 +18,5 @@ def test_OU(self):
self.assertTupleEqual(runner.mon['x'].shape, (100, 1))
x = runner.mon['x']
self.assertLessEqual(abs(x.mean()), 0.1)
- self.assertLessEqual(abs(x.std() - 0.1), 0.1)
\ No newline at end of file
+ self.assertLessEqual(abs(x.std() - 0.1), 0.1)
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/dyn/outs/__init__.py b/brainpy/_src/dyn/outs/__init__.py
new file mode 100644
index 000000000..ac55893ee
--- /dev/null
+++ b/brainpy/_src/dyn/outs/__init__.py
@@ -0,0 +1,2 @@
+from .base import *
+from .outputs import *
diff --git a/brainpy/_src/dyn/outs/base.py b/brainpy/_src/dyn/outs/base.py
new file mode 100644
index 000000000..0a0da5dbd
--- /dev/null
+++ b/brainpy/_src/dyn/outs/base.py
@@ -0,0 +1,21 @@
+from typing import Optional
+
+from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.mixin import ParamDesc, BindCondData
+
+__all__ = [
+ 'SynOut'
+]
+
+
+class SynOut(DynamicalSystem, ParamDesc, BindCondData):
+ """Base class for synaptic outputs."""
+ def __init__(self, name: Optional[str] = None):
+ super().__init__(name=name)
+
+ def __call__(self, *args, **kwargs):
+ if self._conductance is None:
+ raise ValueError(f'Please first pack conductance data at the current step using '
+ f'".{BindCondData.bind_cond.__name__}(data)". {self}')
+ ret = self.update(self._conductance, *args, **kwargs)
+ return ret
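
``SynOut.__call__`` refuses to run until a conductance has been bound for the current step. A minimal sketch of the intended calling convention, assuming ``SynOut`` can be subclassed and instantiated directly; the ``PassThrough`` class is illustrative::

    from brainpy._src.dyn.outs.base import SynOut

    class PassThrough(SynOut):
        # illustrative CUBA-like output: the bound conductance is the current
        def update(self, conductance, potential=None):
            return conductance

    out = PassThrough()
    out.bind_cond(2.5)       # pack the conductance for this step
    print(out(-65.0))        # -> 2.5 ; the potential is ignored by this output
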
diff --git a/brainpy/_src/dyn/synapses/outputs.py b/brainpy/_src/dyn/outs/outputs.py
similarity index 93%
rename from brainpy/_src/dyn/synapses/outputs.py
rename to brainpy/_src/dyn/outs/outputs.py
index bc9783e7b..9a6679d2d 100644
--- a/brainpy/_src/dyn/synapses/outputs.py
+++ b/brainpy/_src/dyn/outs/outputs.py
@@ -1,11 +1,10 @@
-
from typing import Union, Optional, Sequence
import numpy as np
from brainpy import math as bm, initialize as init
-from brainpy._src.dyn.base import SynOut
from brainpy.types import ArrayType
+from .base import SynOut
__all__ = [
'COBA',
@@ -27,6 +26,8 @@ class COBA(SynOut):
----------
E: float, ArrayType, ndarray
The reversal potential.
+ sharding: sequence of str
+ The axis names of the variable, used for parallelization.
name: str
The model name.
@@ -37,7 +38,7 @@ class COBA(SynOut):
def __init__(
self,
- E: Union[float, ArrayType] = 0.,
+ E: Union[float, ArrayType],
sharding: Optional[Sequence[str]] = None,
name: Optional[str] = None,
):
@@ -64,18 +65,10 @@ class CUBA(SynOut):
name: str
The model name.
-
See Also
--------
COBA
"""
-
- def __init__(
- self,
- name: Optional[str] = None,
- ):
- super().__init__(name=name)
-
def update(self, conductance, potential=None):
return conductance
@@ -107,6 +100,8 @@ class MgBlock(SynOut):
Unbinding constant. Default 3.57
cc_Mg: float, ArrayType
Concentration of Magnesium ion. Default 1.2 [mM].
+ sharding: sequence of str
+ The axis names of the variable, used for parallelization.
name: str
The model name.
"""
diff --git a/brainpy/_src/dyn/projections/__init__.py b/brainpy/_src/dyn/projections/__init__.py
new file mode 100644
index 000000000..e58f35554
--- /dev/null
+++ b/brainpy/_src/dyn/projections/__init__.py
@@ -0,0 +1,3 @@
+
+from .aligns import *
+from .others import *
diff --git a/brainpy/_src/dyn/projections.py b/brainpy/_src/dyn/projections/aligns.py
similarity index 70%
rename from brainpy/_src/dyn/projections.py
rename to brainpy/_src/dyn/projections/aligns.py
index 26af51abc..7ad9535c9 100644
--- a/brainpy/_src/dyn/projections.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -2,23 +2,16 @@
from brainpy import math as bm
from brainpy._src.delay import Delay, VariableDelay, DataDelay
-from brainpy._src.dyn.base import NeuDyn, SynOut
-from brainpy._src.dynsys import DynamicalSystemNS, DynamicalSystem
-from brainpy._src.mixin import DelayedInit, ReturnInfo, ProjAutoDelay
+from brainpy._src.dynsys import DynamicalSystem, Projection, NeuDyn
+from brainpy._src.mixin import JointType, ParamDesc, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
__all__ = [
- 'SynProj',
'ProjAlignPre',
'ProjAlignPost',
]
-class SynProj(DynamicalSystemNS):
- """Synaptic projection."""
- pass
-
-
-class _AlignPre(DynamicalSystemNS):
+class _AlignPre(DynamicalSystem):
def __init__(self, syn, delay=None):
super().__init__()
self.syn = syn
@@ -31,8 +24,10 @@ def update(self, x):
return x >> self.syn >> self.delay
-class _AlignPost(DynamicalSystemNS):
- def __init__(self, syn, out):
+class _AlignPost(DynamicalSystem):
+ def __init__(self,
+ syn: Callable,
+ out: JointType[DynamicalSystem, BindCondData]):
super().__init__()
self.syn = syn
self.out = out
@@ -65,7 +60,7 @@ def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay:
raise TypeError
-class ProjAlignPre(SynProj):
+class ProjAlignPre(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
Args:
@@ -81,11 +76,11 @@ class ProjAlignPre(SynProj):
def __init__(
self,
- pre: NeuDyn,
- syn: DelayedInit[ProjAutoDelay],
+ pre: DynamicalSystem,
+ syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]],
delay: Union[None, int, float],
comm: Callable,
- out: SynOut,
+ out: JointType[DynamicalSystem, BindCondData],
post: NeuDyn,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
@@ -93,11 +88,11 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- assert isinstance(pre, NeuDyn)
+ assert isinstance(pre, DynamicalSystem)
+ assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
+ assert isinstance(comm, Callable)
+ assert isinstance(out, JointType[DynamicalSystem, BindCondData])
assert isinstance(post, NeuDyn)
- assert callable(comm)
- assert isinstance(out, SynOut)
- assert isinstance(syn, DelayedInit) and issubclass(syn.cls, ProjAutoDelay)
self.pre = pre
self.post = post
self.comm = comm
@@ -105,8 +100,10 @@ def __init__(
# synapse and delay initialization
self._syn_id = syn._identifier
if self._syn_id not in pre.after_updates:
- syn_cls: ProjAutoDelay = syn()
+ # "syn_cls" needs an instance of "ProjAutoDelay"
+ syn_cls: AutoDelaySupp = syn()
delay_cls = _init_delay(syn_cls.return_info())
+ # add to "after_updates"
pre.after_updates[self._syn_id] = _AlignPre(syn_cls, delay_cls)
delay_cls: Delay = pre.after_updates[self._syn_id].delay
delay_cls.register_entry(self.name, delay)
@@ -114,13 +111,15 @@ def __init__(
# output initialization
post.cur_inputs[self.name] = out
- def update(self):
- current = self.comm(self.pre.after_updates[self._syn_id].delay.at(self.name))
+ def update(self, x=None):
+ if x is None:
+ x = self.pre.after_updates[self._syn_id].delay.at(self.name)
+ current = self.comm(x)
self.post.cur_inputs[self.name].bind_cond(current)
return current
-class ProjAlignPost(SynProj):
+class ProjAlignPost(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
Args:
@@ -136,11 +135,11 @@ class ProjAlignPost(SynProj):
def __init__(
self,
- pre: ProjAutoDelay,
+ pre: JointType[DynamicalSystem, AutoDelaySupp],
delay: Union[None, int, float],
comm: Callable,
- syn: DelayedInit[DynamicalSystem],
- out: DelayedInit[SynOut],
+ syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
+ out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
post: NeuDyn,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
@@ -148,11 +147,11 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- assert isinstance(pre, NeuDyn) and isinstance(pre, ProjAutoDelay)
+ assert isinstance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ assert isinstance(comm, Callable)
+ assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
+ assert isinstance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
assert isinstance(post, NeuDyn)
- assert isinstance(syn, DelayedInit) and issubclass(syn.cls, DynamicalSystem)
- assert isinstance(out, DelayedInit) and issubclass(out.cls, SynOut)
- assert callable(comm)
self.pre = pre
self.post = post
self.comm = comm
@@ -160,7 +159,9 @@ def __init__(
# delay initialization
self._delay_repr = '_*_align_pre_spk_delay_*_'
if self._delay_repr not in self.pre.after_updates:
+ # "pre" must support "AutoDelaySupp"
delay_cls = _init_delay(pre.return_info())
+ # add to "after_updates"
self.pre.after_updates[self._delay_repr] = delay_cls
delay_cls: Delay = pre.after_updates[self._delay_repr]
delay_cls.register_entry(self.name, delay)
@@ -173,7 +174,9 @@ def __init__(
self.post.cur_inputs[self.name] = out_cls
self.post.before_updates[self._post_repr] = _AlignPost(syn_cls, out_cls)
- def update(self):
- current = self.comm(self.pre.after_updates[self._delay_repr].at(self.name))
+ def update(self, x=None):
+ if x is None:
+ x = self.pre.after_updates[self._delay_repr].at(self.name)
+ current = self.comm(x)
self.post.before_updates[self._post_repr].syn.add_current(current) # synapse post current
return current
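
``ProjAlignPre.update`` above reads the delayed pre-synaptic variable, runs it through ``comm``, and binds the resulting conductance on the post-synaptic output. The sketch below mirrors that data flow with plain callables; the names are illustrative and not the BrainPy classes::

    def make_projection(delay_at, comm, bind_cond, name):
        # mirrors ProjAlignPre.update: delayed pre variable -> comm -> bind on the output
        def update(x=None):
            if x is None:
                x = delay_at(name)      # read the delayed pre-synaptic activity
            current = comm(x)           # communication step, e.g. a weight matrix
            bind_cond(current)          # hand the conductance to the post-synaptic output
            return current
        return update

    post_cond = {}
    proj = make_projection(delay_at=lambda name: 1.0,
                           comm=lambda x: 0.3 * x,
                           bind_cond=lambda g: post_cond.update(g=g),
                           name='proj0')
    print(proj())            # 0.3
    print(post_cond['g'])    # 0.3
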
diff --git a/brainpy/_src/dyn/projections/others.py b/brainpy/_src/dyn/projections/others.py
new file mode 100644
index 000000000..506382e2e
--- /dev/null
+++ b/brainpy/_src/dyn/projections/others.py
@@ -0,0 +1,73 @@
+import numbers
+from typing import Union, Optional
+
+from brainpy import check, math as bm
+from brainpy._src.context import share
+from brainpy._src.dynsys import Projection
+
+__all__ = [
+ 'PoissonInput',
+]
+
+
+class PoissonInput(Projection):
+ """Poisson Input to the given :py:class:`~.Variable`.
+
+ Adds independent Poisson input to a target variable. For large
+ numbers of inputs, this is much more efficient than creating a
+ `PoissonGroup`. The synaptic events are generated randomly during the
+ simulation and are not preloaded and stored in memory. All the inputs must
+ target the same variable, have the same frequency and same synaptic weight.
+ All neurons in the target variable receive independent realizations of
+ Poisson spike trains.
+
+ Args:
+ target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`.
+ num_input: The number of inputs.
+ freq: The frequency of each of the inputs. Must be a scalar.
+ weight: The synaptic weight. Must be a scalar.
+ name: The target name.
+ mode: The computing mode.
+ """
+
+ def __init__(
+ self,
+ target_var: bm.Variable,
+ num_input: int,
+ freq: Union[int, float],
+ weight: Union[int, float],
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None
+ ):
+ super().__init__(name=name, mode=mode)
+
+ if not isinstance(target_var, bm.Variable):
+ raise TypeError(f'"target_var" must be an instance of Variable. '
+ f'But we got {type(target_var)}: {target_var}')
+ self.target_var = target_var
+ self.num_input = check.is_integer(num_input, min_bound=1)
+ self.freq = check.is_float(freq, min_bound=0., allow_int=True)
+ self.weight = check.is_float(weight, allow_int=True)
+
+ def update(self):
+ p = self.freq * share['dt'] / 1e3
+ a = self.num_input * p
+ b = self.num_input * (1 - p)
+
+ if isinstance(share['dt'], numbers.Number): # dt is not traced
+ if (a > 5) and (b > 5):
+ inp = bm.random.normal(a, b * p, self.target_var.shape)
+ else:
+ inp = bm.random.binomial(self.num_input, p, self.target_var.shape)
+
+ else: # dt is traced
+ inp = bm.cond((a > 5) * (b > 5),
+ lambda: bm.random.normal(a, b * p, self.target_var.shape),
+ lambda: bm.random.binomial(self.num_input, p, self.target_var.shape),
+ ())
+
+ inp = bm.sharding.partition(inp, self.target_var.sharding)
+ self.target_var += inp * self.weight
+
+ def __repr__(self):
+ return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})'
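
The ``update`` above approximates the per-step ``Binomial(num_input, p)`` spike count with a normal distribution when both ``n*p`` and ``n*(1-p)`` exceed 5, and samples the binomial exactly otherwise. The sketch below reproduces that decision in NumPy using the textbook normal approximation (mean ``n*p``, std ``sqrt(n*p*(1-p))``); it makes no claim about the exact scale convention of ``bm.random.normal``::

    import numpy as np

    def poisson_input_sample(num_input, freq, dt, shape, rng):
        # per-step probability that any single input spikes
        p = freq * dt / 1e3
        a, b = num_input * p, num_input * (1 - p)
        if a > 5 and b > 5:
            # textbook normal approximation of Binomial(n, p)
            return rng.normal(a, np.sqrt(num_input * p * (1 - p)), shape)
        return rng.binomial(num_input, p, shape)

    rng = np.random.default_rng(0)
    print(poisson_input_sample(num_input=1000, freq=50., dt=0.1, shape=(3,), rng=rng))
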
diff --git a/brainpy/_src/rates/__init__.py b/brainpy/_src/dyn/rates/__init__.py
similarity index 100%
rename from brainpy/_src/rates/__init__.py
rename to brainpy/_src/dyn/rates/__init__.py
diff --git a/brainpy/_src/rates/populations.py b/brainpy/_src/dyn/rates/populations.py
similarity index 99%
rename from brainpy/_src/rates/populations.py
rename to brainpy/_src/dyn/rates/populations.py
index c216eb365..afea3c4b2 100644
--- a/brainpy/_src/rates/populations.py
+++ b/brainpy/_src/dyn/rates/populations.py
@@ -1,12 +1,13 @@
# -*- coding: utf-8 -*-
from typing import Union, Callable
+
import jax
from brainpy import math as bm
from brainpy._src.context import share
-from brainpy._src.dynsys import NeuGroupNS
-from brainpy._src.neurons.noise_groups import OUProcess
+from brainpy._src.dyn.others.noise import OUProcess
+from brainpy._src.dynsys import NeuDyn
from brainpy._src.initialize import (Initializer,
Uniform,
parameter,
@@ -28,7 +29,7 @@
]
-class RateModel(NeuGroupNS):
+class RateModel(NeuDyn):
pass
diff --git a/brainpy/_src/rates/tests/test_rates.py b/brainpy/_src/dyn/rates/tests/test_rates.py
similarity index 98%
rename from brainpy/_src/rates/tests/test_rates.py
rename to brainpy/_src/dyn/rates/tests/test_rates.py
index 7e1de6cc9..88c016705 100644
--- a/brainpy/_src/rates/tests/test_rates.py
+++ b/brainpy/_src/dyn/rates/tests/test_rates.py
@@ -3,7 +3,7 @@
import brainpy as bp
from absl.testing import parameterized
-from brainpy._src.rates import populations
+from brainpy._src.dyn.rates import populations
from unittest import TestCase
diff --git a/brainpy/_src/dyn/synapses/__init__.py b/brainpy/_src/dyn/synapses/__init__.py
index e69de29bb..2a296acb5 100644
--- a/brainpy/_src/dyn/synapses/__init__.py
+++ b/brainpy/_src/dyn/synapses/__init__.py
@@ -0,0 +1,3 @@
+
+from .abstract_models import *
+from .bio_models import *
diff --git a/brainpy/_src/dyn/synapses/dynamics.py b/brainpy/_src/dyn/synapses/abstract_models.py
similarity index 61%
rename from brainpy/_src/dyn/synapses/dynamics.py
rename to brainpy/_src/dyn/synapses/abstract_models.py
index cda03d7a4..421cc086c 100644
--- a/brainpy/_src/dyn/synapses/dynamics.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -1,27 +1,90 @@
from typing import Union, Sequence, Callable, Optional
+import jax.numpy
from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.dyn._docs import pneu_doc
-from brainpy._src.dyn.base import SynDyn
+from brainpy._src.dynsys import SynDyn
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.mixin import AlignPost, ReturnInfo
+from brainpy._src.initialize import Constant
from brainpy.types import ArrayType
__all__ = [
+ 'Delta',
'Expon',
'DualExpon',
'Alpha',
'NMDA',
'STD',
'STP',
- 'AMPA',
- 'GABAa',
- 'BioNMDA',
]
+class Delta(SynDyn, AlignPost):
+ r"""Delta synapse model.
+
+ **Model Descriptions**
+
+ The Delta synapse model assumes that the release of neurotransmitter, its diffusion
+ across the cleft, the receptor binding, and the channel opening all happen so quickly
+ that every pre-synaptic spike produces an instantaneous jump of the post-synaptic
+ conductance. Its expression is therefore given by
+
+ .. math::
+
+ g_{\mathrm{syn}}(t) = g_{\mathrm{max}} \sum_{k} \delta(t-t_{k})
+
+ where :math:`t_k` is the time of the :math:`k`-th pre-synaptic spike and
+ :math:`g_{\mathrm{max}}` is the maximal conductance.
+
+ .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
+ "The Synapse." Principles of Computational Modelling in Neuroscience.
+ Cambridge: Cambridge UP, 2011. 172-95. Print.
+
+ Args:
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ sharding: Optional[Sequence[str]] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name,
+ mode=mode,
+ size=size,
+ keep_size=keep_size,
+ sharding=sharding)
+
+ self.reset_state(self.mode)
+
+ def reset_state(self, batch_size=None):
+ self.g = self.init_variable(bm.zeros, batch_size)
+
+ def update(self, x=None):
+ if x is not None:
+ self.g.value += x
+ return self.g.value
+
+ def add_current(self, x):
+ self.g.value += x
+
+ def return_info(self):
+ return self.g
+
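
Since ``Delta`` implements ``AlignPost``, accumulated inputs can arrive either through ``update(x)`` or ``add_current(x)``. A minimal sketch of that accumulation, assuming the class can be instantiated on its own from the module path introduced by this patch::

    import brainpy.math as bm
    from brainpy._src.dyn.synapses.abstract_models import Delta

    syn = Delta(size=4)
    syn.add_current(bm.ones(4) * 0.5)   # post-aligned current from a projection
    g = syn.update()                    # returns the accumulated conductance
    print(g)                            # [0.5 0.5 0.5 0.5]
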
class Expon(SynDyn, AlignPost):
r"""Exponential decay synapse model.
@@ -533,14 +596,18 @@ def reset_state(self, batch_size=None):
def derivative(self):
du = lambda u, t: self.U - u / self.tau_f
dx = lambda x, t: (1 - x) / self.tau_d
- return JointEq([du, dx])
+ return JointEq(du, dx)
def update(self, pre_spike):
- t = share.load('x')
+ t = share.load('t')
dt = share.load('dt')
u, x = self.integral(self.u.value, self.x.value, t, dt)
- u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
- x = bm.where(pre_spike, x - u * self.x, x)
+ if pre_spike.dtype == jax.numpy.bool_:
+ u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
+ x = bm.where(pre_spike, x - u * self.x, x)
+ else:
+ u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
+ x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
self.x.value = x
self.u.value = u
return u * x
@@ -549,316 +616,8 @@ def return_info(self):
return ReturnInfo(size=self.varshape,
batch_or_mode=self.mode,
axis_names=self.sharding,
- init=bm.zeros)
+ init=Constant(self.U))
STP.__doc__ = STP.__doc__ % (pneu_doc,)
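
The ``STP.update`` change above branches on the spike dtype: boolean spikes keep the ``bm.where`` formulation, while float spikes from surrogate-gradient training use an arithmetic blend that stays differentiable. Both forms agree whenever ``pre_spike`` is exactly 0 or 1, as the small NumPy check below verifies::

    import numpy as np

    u, x, U = 0.2, 0.9, 0.15
    for s in (0.0, 1.0):                      # pre_spike as a float "spike"
        u_where = (u + U * (1 - u)) if s else u
        u_blend = s * (u + U * (1 - u)) + (1 - s) * u
        x_where = (x - u_where * x) if s else x
        x_blend = s * (x - u_blend * x) + (1 - s) * x
        assert np.isclose(u_where, u_blend) and np.isclose(x_where, x_blend)
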
-
-class AMPA(SynDyn):
- r"""AMPA synapse model.
-
- **Model Descriptions**
-
- AMPA receptor is an ionotropic receptor, which is an ion channel.
- When it is bound by neurotransmitters, it will immediately open the
- ion channel, causing the change of membrane potential of postsynaptic neurons.
-
- A classical model is to use the Markov process to model ion channel switch.
- Here :math:`g` represents the probability of channel opening, :math:`1-g`
- represents the probability of ion channel closing, and :math:`\alpha` and
- :math:`\beta` are the transition probability. Because neurotransmitters can
- open ion channels, the transfer probability from :math:`1-g` to :math:`g`
- is affected by the concentration of neurotransmitters. We denote the concentration
- of neurotransmitters as :math:`[T]` and get the following Markov process.
-
- .. image:: ../../../_static/synapse_markov.png
- :align: center
-
- We obtained the following formula when describing the process by a differential equation.
-
- .. math::
-
- \frac{ds}{dt} =\alpha[T](1-g)-\beta g
-
- where :math:`\alpha [T]` denotes the transition probability from state :math:`(1-g)`
- to state :math:`(g)`; and :math:`\beta` represents the transition probability of
- the other direction. :math:`\alpha` is the binding constant. :math:`\beta` is the
- unbinding constant. :math:`[T]` is the neurotransmitter concentration, and
- has the duration of 0.5 ms.
-
- Moreover, the post-synaptic current on the post-synaptic neuron is formulated as
-
- .. math::
-
- I_{syn} = g_{max} g (V-E)
-
- where :math:`g_{max}` is the maximum conductance, and `E` is the reverse potential.
-
- .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
- and implications for stimulus processing[J]. Proceedings of the
- National Academy of Sciences, 2012, 109(45): 18553-18558.
-
- Args:
- alpha: float, ArrayType, Callable. Binding constant.
- beta: float, ArrayType, Callable. Unbinding constant.
- T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by
- a pre-synaptic spike.. Default 1 [mM].
- T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. Default 1 [ms]
- %s
- """
-
- supported_modes = (bm.NonBatchingMode,)
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- keep_size: bool = False,
- sharding: Optional[Sequence[str]] = None,
- method: str = 'exp_auto',
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
-
- # synapse parameters
- alpha: Union[float, ArrayType, Callable] = 0.98,
- beta: Union[float, ArrayType, Callable] = 0.18,
- T: Union[float, ArrayType, Callable] = 0.5,
- T_dur: Union[float, ArrayType, Callable] = 0.5,
- ):
- super().__init__(name=name,
- mode=mode,
- size=size,
- keep_size=keep_size,
- sharding=sharding)
-
- # parameters
- self.alpha = self.init_param(alpha)
- self.beta = self.init_param(beta)
- self.T = self.init_param(T)
- self.T_duration = self.init_param(T_dur)
-
- # functions
- self.integral = odeint(method=method, f=self.dg)
-
- self.reset_state(self.mode)
-
- def reset_state(self, batch_size=None):
- self.g = self.init_variable(bm.zeros, batch_size)
- self.spike_arrival_time = self.init_variable(bm.ones, batch_size)
- self.spike_arrival_time.fill(-1e7)
-
- def dg(self, g, t, TT):
- return self.alpha * TT * (1 - g) - self.beta * g
-
- def update(self, pre_spike):
- t = share.load('t')
- dt = share.load('dt')
- self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
- TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T
- self.g.value = self.integral(self.g, t, TT, dt)
- return self.g.value
-
- def return_info(self):
- return self.g
-
-
-AMPA.__doc__ = AMPA.__doc__ % (pneu_doc,)
-
-
-class GABAa(AMPA):
- r"""GABAa synapse model.
-
- **Model Descriptions**
-
- GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_,
-
- .. math::
-
- \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
- I_{syn}&= - g_{max} g (V - E)
-
- but with the difference of:
-
- - Reversal potential of synapse :math:`E` is usually low, typically -80. mV
- - Activating rate constant :math:`\alpha=0.53`
- - De-activating rate constant :math:`\beta=0.18`
- - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is
- triggered by a pre-synaptic spike, with the duration of 1. ms.
-
- .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
- on the integrative properties of neocortical pyramidal neurons
- in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.
-
- Args:
- alpha: float, ArrayType, Callable. Binding constant. Default 0.062
- beta: float, ArrayType, Callable. Unbinding constant. Default 3.57
- T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by
- a pre-synaptic spike.. Default 1 [mM].
- T_dur: float, ArrayType, Callable. Transmitter concentration duration time
- after being triggered. Default 1 [ms]
- %s
- """
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- keep_size: bool = False,
- sharding: Optional[Sequence[str]] = None,
- method: str = 'exp_auto',
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
-
- # synapse parameters
- alpha: Union[float, ArrayType, Callable] = 0.53,
- beta: Union[float, ArrayType, Callable] = 0.18,
- T: Union[float, ArrayType, Callable] = 1.,
- T_dur: Union[float, ArrayType, Callable] = 1.,
- ):
- super().__init__(alpha=alpha,
- beta=beta,
- T=T,
- T_dur=T_dur,
- method=method,
- name=name,
- mode=mode,
- size=size,
- keep_size=keep_size,
- sharding=sharding)
-
-
-GABAa.__doc__ = GABAa.__doc__ % (pneu_doc,)
-
-
-class BioNMDA(SynDyn):
- r"""Biological NMDA synapse model.
-
- **Model Descriptions**
-
- The NMDA receptor is a glutamate receptor and ion channel found in neurons.
- The NMDA receptor is one of three types of ionotropic glutamate receptors,
- the other two being AMPA and kainate receptors.
-
- The NMDA receptor mediated conductance depends on the postsynaptic voltage.
- The voltage dependence is due to the blocking of the pore of the NMDA receptor
- from the outside by a positively charged magnesium ion. The channel is
- nearly completely blocked at resting potential, but the magnesium block is
- relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
- that are not blocked by magnesium can be fitted to
-
- .. math::
-
- g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V}
- \frac{[{Mg}^{2+}]_{o}} {b})^{-1}
-
- Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
- usually 1 mM. Thus, the channel acts as a
- "coincidence detector" and only once both of these conditions are met, the
- channel opens and it allows positively charged ions (cations) to flow through
- the cell membrane [2]_.
-
- If we make the approximation that the magnesium block changes
- instantaneously with voltage and is independent of the gating of the channel,
- the net NMDA receptor-mediated synaptic current is given by
-
- .. math::
-
- I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
-
- where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
- reversal potential.
-
- Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_:
-
- .. math::
-
- & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\
- & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x
-
- where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and
- :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x.
-
- The NMDA receptor has been thought to be very important for controlling
- synaptic plasticity and mediating learning and memory functions [3]_.
-
- .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M].
- Springer New York, 2010: 162.
- .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
- Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
- Nature 438, no. 7065 (2005): 185-192.
- .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
- England journal of medicine, 361(3), p.302.
- .. [4] https://en.wikipedia.org/wiki/NMDA_receptor
-
-
- Args:
- alpha1: float, ArrayType, Callable. The conversion rate of g from inactive to active. Default 2 ms^-1.
- beta1: float, ArrayType, Callable. The conversion rate of g from active to inactive. Default 0.01 ms^-1.
- alpha2: float, ArrayType, Callable. The conversion rate of x from inactive to active. Default 1 ms^-1.
- beta2: float, ArrayType, Callable. The conversion rate of x from active to inactive. Default 0.5 ms^-1.
- T: float, ArrayType, Callable. Transmitter concentration when synapse is
- triggered by a pre-synaptic spike. Default 1 [mM].
- T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. Default 1 [ms]
- %s
- """
- supported_modes = (bm.NonBatchingMode,)
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- keep_size: bool = False,
- sharding: Optional[Sequence[str]] = None,
- method: str = 'exp_auto',
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
-
- # synapse parameters
- alpha1: Union[float, ArrayType, Callable] = 2.,
- beta1: Union[float, ArrayType, Callable] = 0.01,
- alpha2: Union[float, ArrayType, Callable] = 1.,
- beta2: Union[float, ArrayType, Callable] = 0.5,
- T: Union[float, ArrayType, Callable] = 1.,
- T_dur: Union[float, ArrayType, Callable] = 0.5,
- ):
- super().__init__(name=name,
- mode=mode,
- size=size,
- keep_size=keep_size,
- sharding=sharding)
-
- # parameters
- self.beta1 = self.init_param(beta1)
- self.beta2 = self.init_param(beta2)
- self.alpha1 = self.init_param(alpha1)
- self.alpha2 = self.init_param(alpha2)
- self.T = self.init_param(T)
- self.T_dur = self.init_param(T_dur)
-
- # integral
- self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))
-
- self.reset_state(self.mode)
-
- def reset_state(self, batch_size=None):
- self.g = self.init_variable(bm.zeros, batch_size)
- self.x = self.init_variable(bm.zeros, batch_size)
- self.spike_arrival_time = self.init_variable(bm.ones, batch_size)
- self.spike_arrival_time.fill(-1e7)
-
- def dg(self, g, t, x):
- return self.alpha1 * x * (1 - g) - self.beta1 * g
-
- def dx(self, x, t, T):
- return self.alpha2 * T * (1 - x) - self.beta2 * x
-
- def update(self, pre_spike):
- t = share.load('t')
- dt = share.load('dt')
- self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
- T = ((t - self.spike_arrival_time) < self.T_dur) * self.T
- self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt)
- return self.g.value
-
- def return_info(self):
- return self.g
-
-BioNMDA.__doc__ = BioNMDA.__doc__ % (pneu_doc,)
diff --git a/brainpy/_src/dyn/synapses/bio_models.py b/brainpy/_src/dyn/synapses/bio_models.py
new file mode 100644
index 000000000..fd182380a
--- /dev/null
+++ b/brainpy/_src/dyn/synapses/bio_models.py
@@ -0,0 +1,328 @@
+from typing import Union, Sequence, Callable, Optional
+
+import jax.numpy
+from brainpy import math as bm
+from brainpy._src.context import share
+from brainpy._src.dyn._docs import pneu_doc
+from brainpy._src.dynsys import SynDyn
+from brainpy._src.integrators.joint_eq import JointEq
+from brainpy._src.integrators.ode.generic import odeint
+from brainpy._src.mixin import AlignPost, ReturnInfo
+from brainpy._src.initialize import Constant
+from brainpy.types import ArrayType
+
+__all__ = [
+ 'AMPA',
+ 'GABAa',
+ 'BioNMDA',
+]
+
+
+class AMPA(SynDyn):
+ r"""AMPA synapse model.
+
+ **Model Descriptions**
+
+ The AMPA receptor is an ionotropic receptor, i.e. an ion channel.
+ When it is bound by neurotransmitters, it opens immediately,
+ changing the membrane potential of postsynaptic neurons.
+
+ A classical approach is to model the ion-channel switching as a Markov process.
+ Here :math:`g` represents the probability of the channel being open, :math:`1-g`
+ the probability of it being closed, and :math:`\alpha` and :math:`\beta` are the
+ transition rates. Because neurotransmitters can open ion channels, the transition
+ rate from :math:`1-g` to :math:`g` depends on the concentration of neurotransmitters.
+ We denote this concentration as :math:`[T]` and obtain the following Markov process.
+
+ .. image:: ../../../_static/synapse_markov.png
+ :align: center
+
+ We obtained the following formula when describing the process by a differential equation.
+
+ .. math::
+
+ \frac{dg}{dt} = \alpha [T] (1-g) - \beta g
+
+ where :math:`\alpha [T]` denotes the transition rate from state :math:`(1-g)`
+ to state :math:`g`, and :math:`\beta` the rate of the reverse transition.
+ :math:`\alpha` is the binding constant, :math:`\beta` is the unbinding constant,
+ and :math:`[T]` is the neurotransmitter concentration, which lasts 0.5 ms
+ after each pre-synaptic spike.
+
+ Moreover, the post-synaptic current on the post-synaptic neuron is formulated as
+
+ .. math::
+
+ I_{syn} = g_{max} g (V-E)
+
+ where :math:`g_{max}` is the maximum conductance, and :math:`E` is the reversal potential.
+
+ .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
+ and implications for stimulus processing[J]. Proceedings of the
+ National Academy of Sciences, 2012, 109(45): 18553-18558.
+
+ Args:
+ alpha: float, ArrayType, Callable. Binding constant.
+ beta: float, ArrayType, Callable. Unbinding constant.
+ T: float, ArrayType, Callable. Transmitter concentration when the synapse is triggered by
+ a pre-synaptic spike. Default 0.5 [mM].
+ T_dur: float, ArrayType, Callable. Duration of the transmitter concentration after a spike. Default 0.5 [ms].
+ %s
+ """
+
+ supported_modes = (bm.NonBatchingMode, bm.BatchingMode)
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ sharding: Optional[Sequence[str]] = None,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+
+ # synapse parameters
+ alpha: Union[float, ArrayType, Callable] = 0.98,
+ beta: Union[float, ArrayType, Callable] = 0.18,
+ T: Union[float, ArrayType, Callable] = 0.5,
+ T_dur: Union[float, ArrayType, Callable] = 0.5,
+ ):
+ super().__init__(name=name,
+ mode=mode,
+ size=size,
+ keep_size=keep_size,
+ sharding=sharding)
+
+ # parameters
+ self.alpha = self.init_param(alpha)
+ self.beta = self.init_param(beta)
+ self.T = self.init_param(T)
+ self.T_duration = self.init_param(T_dur)
+
+ # functions
+ self.integral = odeint(method=method, f=self.dg)
+
+ self.reset_state(self.mode)
+
+ def reset_state(self, batch_size=None):
+ self.g = self.init_variable(bm.zeros, batch_size)
+ self.spike_arrival_time = self.init_variable(bm.ones, batch_size)
+ self.spike_arrival_time.fill(-1e7)
+
+ def dg(self, g, t, TT):
+ return self.alpha * TT * (1 - g) - self.beta * g
+
+ def update(self, pre_spike):
+ t = share.load('t')
+ dt = share.load('dt')
+ self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
+ TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T
+ self.g.value = self.integral(self.g, t, TT, dt)
+ return self.g.value
+
+ def return_info(self):
+ return self.g
+
+
+AMPA.__doc__ = AMPA.__doc__ % (pneu_doc,)
+
+
+class GABAa(AMPA):
+ r"""GABAa synapse model.
+
+ **Model Descriptions**
+
+ The GABAa synapse model has the same equations as the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_,
+
+ .. math::
+
+ \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
+ I_{syn}&= - g_{max} g (V - E)
+
+ but differs in the following respects:
+
+ - The reversal potential of the synapse :math:`E` is usually low, typically -80 mV.
+ - The activating rate constant is :math:`\alpha=0.53`.
+ - The de-activating rate constant is :math:`\beta=0.18`.
+ - The transmitter concentration is :math:`[T]=1\,\mathrm{mM}` when the synapse is
+ triggered by a pre-synaptic spike, lasting 1 ms.
+
+ .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
+ on the integrative properties of neocortical pyramidal neurons
+ in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.
+
+ Args:
+ alpha: float, ArrayType, Callable. Binding constant. Default 0.53.
+ beta: float, ArrayType, Callable. Unbinding constant. Default 0.18.
+ T: float, ArrayType, Callable. Transmitter concentration when the synapse is triggered by
+ a pre-synaptic spike. Default 1 [mM].
+ T_dur: float, ArrayType, Callable. Duration of the transmitter concentration
+ after being triggered. Default 1 [ms].
+ %s
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ sharding: Optional[Sequence[str]] = None,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+
+ # synapse parameters
+ alpha: Union[float, ArrayType, Callable] = 0.53,
+ beta: Union[float, ArrayType, Callable] = 0.18,
+ T: Union[float, ArrayType, Callable] = 1.,
+ T_dur: Union[float, ArrayType, Callable] = 1.,
+ ):
+ super().__init__(alpha=alpha,
+ beta=beta,
+ T=T,
+ T_dur=T_dur,
+ method=method,
+ name=name,
+ mode=mode,
+ size=size,
+ keep_size=keep_size,
+ sharding=sharding)
+
+
+GABAa.__doc__ = GABAa.__doc__ % (pneu_doc,)
+
+
+class BioNMDA(SynDyn):
+ r"""Biological NMDA synapse model.
+
+ **Model Descriptions**
+
+ The NMDA receptor is a glutamate receptor and ion channel found in neurons.
+ It is one of the three types of ionotropic glutamate receptors, the other
+ two being the AMPA and kainate receptors.
+
+ The NMDA receptor mediated conductance depends on the postsynaptic voltage.
+ The voltage dependence is due to the blocking of the pore of the NMDA receptor
+ from the outside by a positively charged magnesium ion. The channel is
+ nearly completely blocked at resting potential, but the magnesium block is
+ relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
+ that are not blocked by magnesium can be fitted to
+
+ .. math::
+
+ g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V}
+ \frac{[{Mg}^{2+}]_{o}} {b})^{-1}
+
+ Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
+ usually 1 mM. The channel therefore acts as a "coincidence detector": it opens
+ only when glutamate is bound to the receptor and the postsynaptic membrane is
+ depolarized enough to relieve the magnesium block, allowing positively charged
+ ions (cations) to flow through the cell membrane [2]_.
+
+ If we make the approximation that the magnesium block changes
+ instantaneously with voltage and is independent of the gating of the channel,
+ the net NMDA receptor-mediated synaptic current is given by
+
+ .. math::
+
+ I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
+
+ where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
+ reversal potential.
+
+ Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_:
+
+ .. math::
+
+ & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\
+ & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x
+
+ where :math:`\alpha_1, \beta_1` refer to the conversion rates of the variable :math:`g`, and
+ :math:`\alpha_2, \beta_2` refer to the conversion rates of the variable :math:`x`.
+
+ The NMDA receptor has been thought to be very important for controlling
+ synaptic plasticity and mediating learning and memory functions [3]_.
+
+ .. [1] Devaney A J. Mathematical Foundations of Neuroscience[M].
+ Springer New York, 2010: 162.
+ .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
+ Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
+ Nature 438, no. 7065 (2005): 185-192.
+ .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
+ England journal of medicine, 361(3), p.302.
+ .. [4] https://en.wikipedia.org/wiki/NMDA_receptor
+
+
+ Args:
+ alpha1: float, ArrayType, Callable. The conversion rate of g from inactive to active. Default 2 ms^-1.
+ beta1: float, ArrayType, Callable. The conversion rate of g from active to inactive. Default 0.01 ms^-1.
+ alpha2: float, ArrayType, Callable. The conversion rate of x from inactive to active. Default 1 ms^-1.
+ beta2: float, ArrayType, Callable. The conversion rate of x from active to inactive. Default 0.5 ms^-1.
+ T: float, ArrayType, Callable. Transmitter concentration when synapse is
+ triggered by a pre-synaptic spike. Default 1 [mM].
+ T_dur: float, ArrayType, Callable. Duration of the transmitter concentration after being triggered. Default 0.5 [ms].
+ %s
+ """
+ supported_modes = (bm.NonBatchingMode, bm.BatchingMode)
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ sharding: Optional[Sequence[str]] = None,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+
+ # synapse parameters
+ alpha1: Union[float, ArrayType, Callable] = 2.,
+ beta1: Union[float, ArrayType, Callable] = 0.01,
+ alpha2: Union[float, ArrayType, Callable] = 1.,
+ beta2: Union[float, ArrayType, Callable] = 0.5,
+ T: Union[float, ArrayType, Callable] = 1.,
+ T_dur: Union[float, ArrayType, Callable] = 0.5,
+ ):
+ super().__init__(name=name,
+ mode=mode,
+ size=size,
+ keep_size=keep_size,
+ sharding=sharding)
+
+ # parameters
+ self.beta1 = self.init_param(beta1)
+ self.beta2 = self.init_param(beta2)
+ self.alpha1 = self.init_param(alpha1)
+ self.alpha2 = self.init_param(alpha2)
+ self.T = self.init_param(T)
+ self.T_dur = self.init_param(T_dur)
+
+ # integral
+ self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))
+
+ self.reset_state(self.mode)
+
+ def reset_state(self, batch_size=None):
+ self.g = self.init_variable(bm.zeros, batch_size)
+ self.x = self.init_variable(bm.zeros, batch_size)
+ self.spike_arrival_time = self.init_variable(bm.ones, batch_size)
+ self.spike_arrival_time.fill(-1e7)
+
+ def dg(self, g, t, x):
+ return self.alpha1 * x * (1 - g) - self.beta1 * g
+
+ def dx(self, x, t, T):
+ return self.alpha2 * T * (1 - x) - self.beta2 * x
+
+ def update(self, pre_spike):
+ t = share.load('t')
+ dt = share.load('dt')
+ self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time)
+ T = ((t - self.spike_arrival_time) < self.T_dur) * self.T
+ self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt)
+ return self.g.value
+
+ def return_info(self):
+ return self.g
+
+
+BioNMDA.__doc__ = BioNMDA.__doc__ % (pneu_doc,)
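The kinetics added above can be sanity-checked in isolation. Below is a minimal standalone sketch (plain Python, not using the BrainPy classes) that integrates the first-order AMPA/GABAa gating equation and the second-order BioNMDA equations with a forward-Euler step. The parameter values mirror the defaults introduced in this patch, and the transmitter concentration follows a rectangular pulse of width T_dur after one pre-synaptic spike, as in the update() methods:

# Illustration only: Euler integration of the synapse kinetics in bio_models.py.
dt = 0.01              # integration step [ms]
steps = int(40 / dt)   # simulate 40 ms
spike_time = 5.0       # one pre-synaptic spike at t = 5 ms

# AMPA defaults in this patch: alpha=0.98, beta=0.18, T=0.5, T_dur=0.5
alpha, beta, T_ampa, Tdur_ampa = 0.98, 0.18, 0.5, 0.5
# BioNMDA defaults: alpha1=2., beta1=0.01, alpha2=1., beta2=0.5, T=1., T_dur=0.5
a1, b1, a2, b2, T_nmda, Tdur_nmda = 2., 0.01, 1., 0.5, 1., 0.5

g_ampa, g_nmda, x_nmda = 0., 0., 0.
for i in range(steps):
    t = i * dt
    # rectangular transmitter pulses, as computed in update()
    TT_a = T_ampa if 0. <= t - spike_time < Tdur_ampa else 0.
    TT_n = T_nmda if 0. <= t - spike_time < Tdur_nmda else 0.
    # AMPA / GABAa: dg/dt = alpha * [T] * (1 - g) - beta * g
    g_ampa += dt * (alpha * TT_a * (1 - g_ampa) - beta * g_ampa)
    # BioNMDA: dg/dt = alpha1 * x * (1 - g) - beta1 * g
    #          dx/dt = alpha2 * [T] * (1 - x) - beta2 * x
    dg = a1 * x_nmda * (1 - g_nmda) - b1 * g_nmda
    dx = a2 * TT_n * (1 - x_nmda) - b2 * x_nmda
    g_nmda += dt * dg
    x_nmda += dt * dx

print(f"g_AMPA(40 ms) = {g_ampa:.4f}, g_NMDA(40 ms) = {g_nmda:.4f}")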
diff --git a/brainpy/_src/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py
similarity index 99%
rename from brainpy/_src/synapses/delay_couplings.py
rename to brainpy/_src/dyn/synapses/delay_couplings.py
index c1fd8513b..4ce50c3ee 100644
--- a/brainpy/_src/synapses/delay_couplings.py
+++ b/brainpy/_src/dyn/synapses/delay_couplings.py
@@ -6,7 +6,7 @@
from jax import vmap
import brainpy.math as bm
-from brainpy._src.dynsys import SynConn
+from brainpy._src.dynsys import DynSysGroup as SynConn
from brainpy._src.neurons.input_groups import InputGroup, OutputGroup
from brainpy._src.initialize import Initializer
from brainpy.check import is_sequence
diff --git a/brainpy/_src/synapses/gap_junction.py b/brainpy/_src/dyn/synapses/gap_junction.py
similarity index 94%
rename from brainpy/_src/synapses/gap_junction.py
rename to brainpy/_src/dyn/synapses/gap_junction.py
index b6164da91..c9432d3b0 100644
--- a/brainpy/_src/synapses/gap_junction.py
+++ b/brainpy/_src/dyn/synapses/gap_junction.py
@@ -4,7 +4,7 @@
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector
-from brainpy._src.dynsys import NeuGroup, TwoEndConn
+from brainpy._src.dynsys import NeuDyn, DynamicalSystem as TwoEndConn
from brainpy._src.initialize import Initializer, parameter
from brainpy.types import ArrayType
@@ -16,8 +16,8 @@
class GapJunction(TwoEndConn):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
diff --git a/brainpy/_src/synapses/tests/test_delay_couplings.py b/brainpy/_src/dyn/synapses/test_delay_couplings.py
similarity index 93%
rename from brainpy/_src/synapses/tests/test_delay_couplings.py
rename to brainpy/_src/dyn/synapses/test_delay_couplings.py
index d94ea89c6..51af9d685 100644
--- a/brainpy/_src/synapses/tests/test_delay_couplings.py
+++ b/brainpy/_src/dyn/synapses/test_delay_couplings.py
@@ -10,6 +10,7 @@
class Test_delay_couplings(parameterized.TestCase):
def test_DiffusiveCoupling(self):
+ bm.random.seed()
areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn1')
conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input,
conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require('conn_mat'),
@@ -22,8 +23,10 @@ def test_DiffusiveCoupling(self):
inputs=('fhn1.input', 35.))
runner(10.)
self.assertTupleEqual(runner.mon['fhn1.x'].shape, (100, 80))
+ bm.clear_buffer_memory()
def test_AdditiveCoupling(self):
+ bm.random.seed()
areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn2')
conn = bp.synapses.AdditiveCoupling(areas.x, areas.input,
conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require('conn_mat'),
@@ -36,3 +39,4 @@ def test_AdditiveCoupling(self):
inputs=('fhn2.input', 35.))
runner(10.)
self.assertTupleEqual(runner.mon['fhn2.x'].shape, (100, 80))
+ bm.clear_buffer_memory()
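The changes to these tests follow a small hygiene pattern: seed BrainPy's global random state at the start of each test (so stochastic models such as FHN stay reproducible) and release cached device buffers at the end (so long test runs do not accumulate memory). A minimal sketch of the same pattern, assuming the standard DSRunner API; the model and monitor names are illustrative, not part of this patch:

import unittest

import brainpy as bp
import brainpy.math as bm


class TestWithCleanState(unittest.TestCase):
  def test_example(self):
    bm.random.seed()                      # reproducible stochastic terms per test
    group = bp.rates.FHN(10, x_ou_sigma=0.01, y_ou_sigma=0.01)
    runner = bp.DSRunner(group, monitors=['x'], progress_bar=False)
    runner(5.)                            # run for 5 ms
    self.assertEqual(runner.mon['x'].shape[1], 10)
    bm.clear_buffer_memory()              # free JAX buffers before the next test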
diff --git a/brainpy/_src/synapses/tests/test_gap_junction.py b/brainpy/_src/dyn/synapses/test_gap_junction.py
similarity index 93%
rename from brainpy/_src/synapses/tests/test_gap_junction.py
rename to brainpy/_src/dyn/synapses/test_gap_junction.py
index cd3c00d3a..c3ff9440b 100644
--- a/brainpy/_src/synapses/tests/test_gap_junction.py
+++ b/brainpy/_src/dyn/synapses/test_gap_junction.py
@@ -10,6 +10,7 @@
class Test_gap_junction(parameterized.TestCase):
def test_gap_junction(self):
+ bm.random.seed()
neu = bp.neurons.HH(2, V_initializer=bp.init.Constant(-70.68))
syn = gap_junction.GapJunction(neu, neu, conn=bp.connect.All2All(include_self=False))
net = bp.Network(syn=syn, neu=neu)
@@ -20,3 +21,4 @@ def test_gap_junction(self):
inputs=('neu.input', 35.))
runner(10.)
self.assertTupleEqual(runner.mon['neu.V'].shape, (100, 2))
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/dyn/utils.py b/brainpy/_src/dyn/utils.py
new file mode 100644
index 000000000..0af1d4532
--- /dev/null
+++ b/brainpy/_src/dyn/utils.py
@@ -0,0 +1,16 @@
+from typing import Optional, Union
+import brainpy.math as bm
+
+__all__ = [
+ 'get_spk_type',
+]
+
+
+def get_spk_type(spk_type: Optional[type] = None, mode: Optional[bm.Mode] = None):
+ if mode is None:
+ return bm.bool
+ elif isinstance(mode, bm.TrainingMode):
+ return bm.float_ if (spk_type is None) else spk_type
+ else:
+ assert isinstance(mode, bm.Mode)
+ return bm.bool if (spk_type is None) else spk_type
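For reference, a short usage sketch of the helper above; it assumes the mode classes in brainpy.math can be instantiated directly, and the expected results follow from the branches of get_spk_type:

import brainpy.math as bm
from brainpy._src.dyn.utils import get_spk_type

# No mode given: defaults to boolean spikes.
assert get_spk_type() == bm.bool
# Training mode: float spikes so surrogate gradients can propagate.
assert get_spk_type(mode=bm.TrainingMode()) == bm.float_
# Other modes: boolean unless the caller overrides spk_type.
assert get_spk_type(mode=bm.NonBatchingMode()) == bm.bool
assert get_spk_type(spk_type=bm.float_, mode=bm.NonBatchingMode()) == bm.float_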
diff --git a/brainpy/_src/synapses_v2/__init__.py b/brainpy/_src/dynold/__init__.py
similarity index 100%
rename from brainpy/_src/synapses_v2/__init__.py
rename to brainpy/_src/dynold/__init__.py
diff --git a/brainpy/_src/dynold/experimental/__init__.py b/brainpy/_src/dynold/experimental/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/brainpy/_src/synapses_v2/abstract_synapses.py b/brainpy/_src/dynold/experimental/abstract_synapses.py
similarity index 99%
rename from brainpy/_src/synapses_v2/abstract_synapses.py
rename to brainpy/_src/dynold/experimental/abstract_synapses.py
index 16783f18e..c24442461 100644
--- a/brainpy/_src/synapses_v2/abstract_synapses.py
+++ b/brainpy/_src/dynold/experimental/abstract_synapses.py
@@ -7,7 +7,7 @@
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector, All2All, One2One
from brainpy._src.context import share
-from brainpy._src.synapses_v2.base import SynConnNS, SynOutNS, SynSTPNS
+from brainpy._src.dynold.experimental.base import SynConnNS, SynOutNS, SynSTPNS
from brainpy._src.initialize import Initializer, variable_
from brainpy._src.integrators import odeint, JointEq
from brainpy.check import is_float
diff --git a/brainpy/_src/synapses_v2/base.py b/brainpy/_src/dynold/experimental/base.py
similarity index 96%
rename from brainpy/_src/synapses_v2/base.py
rename to brainpy/_src/dynold/experimental/base.py
index 40010e574..0ff0d6cbc 100644
--- a/brainpy/_src/synapses_v2/base.py
+++ b/brainpy/_src/dynold/experimental/base.py
@@ -5,12 +5,12 @@
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector, All2All, One2One, MatConn, IJConn
-from brainpy._src.dynsys import DynamicalSystemNS
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import Initializer, parameter
from brainpy.types import ArrayType
-class SynConnNS(DynamicalSystemNS):
+class SynConnNS(DynamicalSystem):
def __init__(
self,
conn: TwoEndConnector,
@@ -118,7 +118,7 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
return post_vs
-class SynOutNS(DynamicalSystemNS):
+class SynOutNS(DynamicalSystem):
def update(self, post_g, post_v):
raise NotImplementedError
@@ -126,7 +126,7 @@ def reset_state(self, batch_size: Optional[int] = None):
pass
-class SynSTPNS(DynamicalSystemNS):
+class SynSTPNS(DynamicalSystem):
"""Base class for synaptic short-term plasticity."""
def update(self, pre_spike):
diff --git a/brainpy/_src/synapses_v2/others.py b/brainpy/_src/dynold/experimental/others.py
similarity index 96%
rename from brainpy/_src/synapses_v2/others.py
rename to brainpy/_src/dynold/experimental/others.py
index 0dfb2b105..9bd6d1fac 100644
--- a/brainpy/_src/synapses_v2/others.py
+++ b/brainpy/_src/dynold/experimental/others.py
@@ -2,12 +2,12 @@
from typing import Union, Optional
import brainpy.math as bm
-from brainpy._src.dynsys import DynamicalSystemNS
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
from brainpy.check import is_float, is_integer
-class PoissonInput(DynamicalSystemNS):
+class PoissonInput(DynamicalSystem):
"""Poisson Input.
Adds independent Poisson input to a target variable. For large
diff --git a/brainpy/_src/synapses_v2/syn_outs.py b/brainpy/_src/dynold/experimental/syn_outs.py
similarity index 97%
rename from brainpy/_src/synapses_v2/syn_outs.py
rename to brainpy/_src/dynold/experimental/syn_outs.py
index 5492513da..10f3277ec 100644
--- a/brainpy/_src/synapses_v2/syn_outs.py
+++ b/brainpy/_src/dynold/experimental/syn_outs.py
@@ -2,7 +2,7 @@
from typing import Union
-from brainpy._src.synapses_v2.base import SynOutNS
+from brainpy._src.dynold.experimental.base import SynOutNS
from brainpy.math import exp
from brainpy.types import ArrayType
diff --git a/brainpy/_src/synapses_v2/syn_plasticity.py b/brainpy/_src/dynold/experimental/syn_plasticity.py
similarity index 98%
rename from brainpy/_src/synapses_v2/syn_plasticity.py
rename to brainpy/_src/dynold/experimental/syn_plasticity.py
index 384dbafef..e5570c2b2 100644
--- a/brainpy/_src/synapses_v2/syn_plasticity.py
+++ b/brainpy/_src/dynold/experimental/syn_plasticity.py
@@ -4,9 +4,9 @@
import jax.numpy as jnp
-from brainpy._src.context import share
from brainpy import math as bm, tools
-from brainpy._src.synapses_v2.base import SynSTPNS
+from brainpy._src.context import share
+from brainpy._src.dynold.experimental.base import SynSTPNS
from brainpy._src.initialize import variable_, OneInit, parameter
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import ArrayType, Shape
diff --git a/brainpy/_src/neurons/__init__.py b/brainpy/_src/dynold/neurons/__init__.py
similarity index 68%
rename from brainpy/_src/neurons/__init__.py
rename to brainpy/_src/dynold/neurons/__init__.py
index 8b9540ab6..e4e413d69 100644
--- a/brainpy/_src/neurons/__init__.py
+++ b/brainpy/_src/dynold/neurons/__init__.py
@@ -3,5 +3,3 @@
from .biological_models import *
from .fractional_models import *
from .reduced_models import *
-from .input_groups import *
-from .noise_groups import *
diff --git a/brainpy/_src/neurons/biological_models.py b/brainpy/_src/dynold/neurons/biological_models.py
similarity index 71%
rename from brainpy/_src/neurons/biological_models.py
rename to brainpy/_src/dynold/neurons/biological_models.py
index 9c533012f..2adad502c 100644
--- a/brainpy/_src/neurons/biological_models.py
+++ b/brainpy/_src/dynold/neurons/biological_models.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
-from typing import Union, Callable, Optional
+from typing import Union, Callable
import brainpy.math as bm
from brainpy import check
-from brainpy._src.dynsys import NeuGroupNS
from brainpy._src.context import share
+from brainpy._src.dyn.neurons import hh
+from brainpy._src.dynsys import NeuDyn
from brainpy._src.initialize import (OneInit,
- Uniform,
Initializer,
parameter,
noise as init_noise,
@@ -25,7 +25,7 @@
]
-class HH(NeuGroupNS):
+class HH(hh.HH):
r"""Hodgkin–Huxley neuron model.
**Model Descriptions**
@@ -198,137 +198,32 @@ class HH(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- ENa: Union[float, ArrayType, Initializer, Callable] = 50.,
- gNa: Union[float, ArrayType, Initializer, Callable] = 120.,
- EK: Union[float, ArrayType, Initializer, Callable] = -77.,
- gK: Union[float, ArrayType, Initializer, Callable] = 36.,
- EL: Union[float, ArrayType, Initializer, Callable] = -54.387,
- gL: Union[float, ArrayType, Initializer, Callable] = 0.03,
- V_th: Union[float, ArrayType, Initializer, Callable] = 20.,
- C: Union[float, ArrayType, Initializer, Callable] = 1.0,
- V_initializer: Union[Initializer, Callable, ArrayType] = Uniform(-70, -60.),
- m_initializer: Optional[Union[Initializer, Callable, ArrayType]] = None,
- h_initializer: Optional[Union[Initializer, Callable, ArrayType]] = None,
- n_initializer: Optional[Union[Initializer, Callable, ArrayType]] = None,
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- name: str = None,
- input_var: bool = True,
-
- # training parameter
- mode: bm.Mode = None,
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super(HH, self).__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode)
- assert self.mode.is_one_of(bm.BatchingMode, bm.NonBatchingMode)
-
- # parameters
- self.ENa = parameter(ENa, self.varshape, allow_none=False)
- self.EK = parameter(EK, self.varshape, allow_none=False)
- self.EL = parameter(EL, self.varshape, allow_none=False)
- self.gNa = parameter(gNa, self.varshape, allow_none=False)
- self.gK = parameter(gK, self.varshape, allow_none=False)
- self.gL = parameter(gL, self.varshape, allow_none=False)
- self.C = parameter(C, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape, num_vars=4)
self.input_var = input_var
-
- # initializers
- check.is_initializer(m_initializer, 'm_initializer', allow_none=True)
- check.is_initializer(h_initializer, 'h_initializer', allow_none=True)
- check.is_initializer(n_initializer, 'n_initializer', allow_none=True)
- check.is_initializer(V_initializer, 'V_initializer', allow_none=False)
- self._m_initializer = m_initializer
- self._h_initializer = h_initializer
- self._n_initializer = n_initializer
- self._V_initializer = V_initializer
-
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
- # model
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # m channel
- m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
- m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18)
- m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V))
- dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m
-
- # h channel
- h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.)
- h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10))
- h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V))
- dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h
-
- # n channel
- n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
- n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80)
- n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V))
- dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
- if self._m_initializer is None:
- self.m = bm.Variable(self.m_inf(self.V.value), batch_axis=self.V.batch_axis)
- else:
- self.m = variable_(self._m_initializer, self.varshape, batch_size)
- if self._h_initializer is None:
- self.h = bm.Variable(self.h_inf(self.V.value), batch_axis=self.V.batch_axis)
- else:
- self.h = variable_(self._h_initializer, self.varshape, batch_size)
- if self._n_initializer is None:
- self.n = bm.Variable(self.n_inf(self.V.value), batch_axis=self.V.batch_axis)
- else:
- self.n = variable_(self._n_initializer, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
-
- def dV(self, V, t, m, h, n, I):
- I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
- I_K = (self.gK * n ** 4.0) * (V - self.EK)
- I_leak = self.gL * (V - self.EL)
- dVdt = (- I_Na - I_K - I_leak + I) / self.C
- return dVdt
-
- @property
- def derivative(self):
- return JointEq(self.dV, self.dm, self.dh, self.dn)
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
- V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt)
- self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
- self.V.value = V
- self.m.value = m
- self.h.value = h
- self.n.value = n
- return self.spike.value
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class MorrisLecar(NeuGroupNS):
+class MorrisLecar(hh.MorrisLecar):
r"""The Morris-Lecar neuron model.
**Model Descriptions**
@@ -403,116 +298,32 @@ class MorrisLecar(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- V_Ca: Union[float, ArrayType, Initializer, Callable] = 130.,
- g_Ca: Union[float, ArrayType, Initializer, Callable] = 4.4,
- V_K: Union[float, ArrayType, Initializer, Callable] = -84.,
- g_K: Union[float, ArrayType, Initializer, Callable] = 8.,
- V_leak: Union[float, ArrayType, Initializer, Callable] = -60.,
- g_leak: Union[float, ArrayType, Initializer, Callable] = 2.,
- C: Union[float, ArrayType, Initializer, Callable] = 20.,
- V1: Union[float, ArrayType, Initializer, Callable] = -1.2,
- V2: Union[float, ArrayType, Initializer, Callable] = 18.,
- V3: Union[float, ArrayType, Initializer, Callable] = 2.,
- V4: Union[float, ArrayType, Initializer, Callable] = 30.,
- phi: Union[float, ArrayType, Initializer, Callable] = 0.04,
- V_th: Union[float, ArrayType, Initializer, Callable] = 10.,
- W_initializer: Union[Callable, Initializer, ArrayType] = OneInit(0.02),
- V_initializer: Union[Callable, Initializer, ArrayType] = Uniform(-70., -60.),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- name: str = None,
- input_var: bool = True,
-
- # training parameter
- mode: bm.Mode = None,
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super(MorrisLecar, self).__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode)
- assert self.mode.is_one_of(bm.BatchingMode, bm.NonBatchingMode)
-
- # params
- self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False)
- self.g_Ca = parameter(g_Ca, self.varshape, allow_none=False)
- self.V_K = parameter(V_K, self.varshape, allow_none=False)
- self.g_K = parameter(g_K, self.varshape, allow_none=False)
- self.V_leak = parameter(V_leak, self.varshape, allow_none=False)
- self.g_leak = parameter(g_leak, self.varshape, allow_none=False)
- self.C = parameter(C, self.varshape, allow_none=False)
- self.V1 = parameter(V1, self.varshape, allow_none=False)
- self.V2 = parameter(V2, self.varshape, allow_none=False)
- self.V3 = parameter(V3, self.varshape, allow_none=False)
- self.V4 = parameter(V4, self.varshape, allow_none=False)
- self.phi = parameter(phi, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape, num_vars=2)
self.input_var = input_var
-
- # initializers
- self._W_initializer = check.is_initializer(W_initializer, allow_none=False)
- self._V_initializer = check.is_initializer(V_initializer, allow_none=False)
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
def reset_state(self, batch_size=None):
- self.W = variable_(self._W_initializer, self.varshape, batch_size)
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
- self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- def dV(self, V, t, W, I_ext):
- M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
- I_Ca = self.g_Ca * M_inf * (V - self.V_Ca)
- I_K = self.g_K * W * (V - self.V_K)
- I_Leak = self.g_leak * (V - self.V_leak)
- dVdt = (- I_Ca - I_K - I_Leak + I_ext) / self.C
- return dVdt
-
- def dW(self, W, t, V):
- tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4)))
- W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4))
- dWdt = (W_inf - W) / tau_W
- return dWdt
-
- @property
- def derivative(self):
- return JointEq(self.dV, self.dW)
-
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
- V, W = self.integral(self.V, self.W, t, x, dt)
- spike = bm.logical_and(self.V < self.V_th, V >= self.V_th)
- self.V.value = V
- self.W.value = W
- self.spike.value = spike
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class PinskyRinzelModel(NeuGroupNS):
+class PinskyRinzelModel(NeuDyn):
r"""The Pinsky and Rinsel (1994) model.
The Pinsky and Rinsel (1994) model [7]_ is a 2-compartment (soma and dendrite),
@@ -661,6 +472,8 @@ class PinskyRinzelModel(NeuGroupNS):
neurophysiology, 66(2), 635-650.
"""
+ supported_modes = (bm.BatchingMode, bm.NonBatchingMode)
+
def __init__(
self,
size: Shape,
@@ -698,7 +511,6 @@ def __init__(
keep_size=keep_size,
name=name,
mode=mode)
- assert self.mode.is_one_of(bm.BatchingMode, bm.NonBatchingMode)
# conductance parameters
self.gAHP = parameter(gAHP, self.varshape, allow_none=False)
@@ -800,7 +612,7 @@ def dVd(self, Vd, t, s, q, c, Ca, Vs):
@property
def derivative(self):
- return JointEq([self.dVs, self.dVd, self.dCa, self.dh, self.dn, self.ds, self.dc, self.dq])
+ return JointEq(self.dVs, self.dVd, self.dCa, self.dh, self.dn, self.ds, self.dc, self.dq)
def update(self, x=None):
assert x is None
@@ -826,8 +638,8 @@ def update(self, x=None):
self.q.value = q
def clear_input(self):
- self.Id[:] = 0.
- self.Is[:] = 0.
+ self.Id.value = bm.zeros_like(self.Id)
+ self.Is.value = bm.zeros_like(self.Is)
def alpha_m(self, Vs):
return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.)
@@ -899,7 +711,7 @@ def inf_q(self, Ca):
return alpha / (alpha + beta)
-class WangBuzsakiModel(NeuGroupNS):
+class WangBuzsakiModel(hh.WangBuzsakiHH):
r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model.
Each model is described by a single compartment and obeys the current balance equation:
@@ -985,118 +797,26 @@ class WangBuzsakiModel(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- ENa: Union[float, ArrayType, Initializer, Callable] = 55.,
- gNa: Union[float, ArrayType, Initializer, Callable] = 35.,
- EK: Union[float, ArrayType, Initializer, Callable] = -90.,
- gK: Union[float, ArrayType, Initializer, Callable] = 9.,
- EL: Union[float, ArrayType, Initializer, Callable] = -65,
- gL: Union[float, ArrayType, Initializer, Callable] = 0.1,
- V_th: Union[float, ArrayType, Initializer, Callable] = 20.,
- phi: Union[float, ArrayType, Initializer, Callable] = 5.0,
- C: Union[float, ArrayType, Initializer, Callable] = 1.0,
- V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-65.),
- h_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.6),
- n_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.32),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- input_var: bool = True,
- name: str = None,
- mode: bm.Mode = None,
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode)
- assert self.mode.is_one_of(bm.BatchingMode, bm.NonBatchingMode)
-
- # parameters
- self.ENa = parameter(ENa, self.varshape, allow_none=False)
- self.EK = parameter(EK, self.varshape, allow_none=False)
- self.EL = parameter(EL, self.varshape, allow_none=False)
- self.gNa = parameter(gNa, self.varshape, allow_none=False)
- self.gK = parameter(gK, self.varshape, allow_none=False)
- self.gL = parameter(gL, self.varshape, allow_none=False)
- self.C = parameter(C, self.varshape, allow_none=False)
- self.phi = parameter(phi, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape, num_vars=3)
self.input_var = input_var
-
- # initializers
- check.is_initializer(h_initializer, 'h_initializer', allow_none=False)
- check.is_initializer(n_initializer, 'n_initializer', allow_none=False)
- check.is_initializer(V_initializer, 'V_initializer', allow_none=False)
- self._h_initializer = h_initializer
- self._n_initializer = n_initializer
- self._V_initializer = V_initializer
-
- # variables
- self.h = variable_(self._h_initializer, self.varshape, self.mode)
- self.n = variable_(self._n_initializer, self.varshape, self.mode)
- self.V = variable_(self._V_initializer, self.varshape, self.mode)
- self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
- if self.input_var:
- self.input = variable_(bm.zeros, self.varshape, self.mode)
-
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
+ super().__init__(*args, **kwargs, init_var=False)
+ self.reset_state(self.mode)
def reset_state(self, batch_size=None):
- self.h.value = variable_(self._h_initializer, self.varshape, batch_size)
- self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
- self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
- self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
- def m_inf(self, V):
- alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
- beta = 4. * bm.exp(-(V + 60.) / 18.)
- return alpha / (alpha + beta)
-
- def dh(self, h, t, V):
- alpha = 0.07 * bm.exp(-(V + 58) / 20)
- beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1)
- dhdt = alpha * (1 - h) - beta * h
- return self.phi * dhdt
-
- def dn(self, n, t, V):
- alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1)
- beta = 0.125 * bm.exp(-(V + 44) / 80)
- dndt = alpha * (1 - n) - beta * n
- return self.phi * dndt
-
- def dV(self, V, t, h, n, I_ext):
- INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa)
- IK = self.gK * n ** 4 * (V - self.EK)
- IL = self.gL * (V - self.EL)
- dVdt = (- INa - IK - IL + I_ext) / self.C
- return dVdt
-
- @property
- def derivative(self):
- return JointEq(self.dV, self.dh, self.dn)
-
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
- V, h, n = self.integral(self.V, self.h, self.n, t, x, dt)
- self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
- self.V.value = V
- self.h.value = h
- self.n.value = n
- return self.spike.value
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
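The rewrites above all follow the same delegation pattern: the legacy classes (HH, MorrisLecar, WangBuzsakiModel, and the reduced models below) now subclass the new implementations in brainpy._src.dyn and keep only the legacy input_var bookkeeping, passing init_var=False so that variables are created only after the wrapper has recorded its own settings. A generic sketch of that pattern, with hypothetical class names that are not part of this patch:

# Illustrative only: NewModel / LegacyModel do not exist in BrainPy.
class NewModel:
  def __init__(self, size, init_var=True):
    self.size = size
    if init_var:              # the new classes create their variables eagerly ...
      self.reset_state()

  def reset_state(self, batch_size=None):
    self.V = [0.0] * self.size

  def update(self, x=0.0):
    # ... integrate the model equations with external input x ...
    return self.V


class LegacyModel(NewModel):
  def __init__(self, *args, input_var=True, **kwargs):
    self.input_var = input_var
    super().__init__(*args, **kwargs, init_var=False)  # ... but the wrapper defers creation
    self.reset_state()

  def reset_state(self, batch_size=None):
    super().reset_state(batch_size)
    if self.input_var:
      self.input = 0.0        # legacy accumulated-input variable

  def update(self, x=None):
    if self.input_var:        # legacy behaviour: accumulate external input first
      if x is not None:
        self.input += x
      x = self.input
    else:
      x = 0.0 if x is None else x
    return super().update(x)

  def clear_input(self):
    if self.input_var:
      self.input = 0.0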
diff --git a/brainpy/_src/neurons/fractional_models.py b/brainpy/_src/dynold/neurons/fractional_models.py
similarity index 98%
rename from brainpy/_src/neurons/fractional_models.py
rename to brainpy/_src/dynold/neurons/fractional_models.py
index 0bde9b4d5..09babeb78 100644
--- a/brainpy/_src/neurons/fractional_models.py
+++ b/brainpy/_src/dynold/neurons/fractional_models.py
@@ -3,9 +3,10 @@
from typing import Union, Sequence, Callable
import jax.numpy as jnp
+
import brainpy.math as bm
-from brainpy._src.dynsys import NeuGroupNS
from brainpy._src.context import share
+from brainpy._src.dynsys import NeuDyn
from brainpy._src.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy._src.integrators.fde import CaputoL1Schema
from brainpy._src.integrators.fde import GLShortMemory
@@ -20,7 +21,7 @@
]
-class FractionalNeuron(NeuGroupNS):
+class FractionalNeuron(NeuDyn):
"""Fractional-order neuron model."""
pass
@@ -318,15 +319,13 @@ def derivative(self):
return JointEq(self.dV, self.du)
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
- V, u = self.integral(self.V, self.u, t=t, I_ext=x, dt=dt)
+ V, u = self.integral(self.V, self.u, t=share['t'], I_ext=x, dt=share['dt'])
spikes = V >= self.V_th
self.V.value = jnp.where(spikes, self.c, V)
self.u.value = jnp.where(spikes, u + self.d, u)
diff --git a/brainpy/_src/neurons/reduced_models.py b/brainpy/_src/dynold/neurons/reduced_models.py
similarity index 61%
rename from brainpy/_src/neurons/reduced_models.py
rename to brainpy/_src/dynold/neurons/reduced_models.py
index 018d54aaa..a0c42141d 100644
--- a/brainpy/_src/neurons/reduced_models.py
+++ b/brainpy/_src/dynold/neurons/reduced_models.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
-from functools import partial
-from typing import Union, Callable, Optional
+from typing import Union, Callable
from jax.lax import stop_gradient
import brainpy.math as bm
-from brainpy._src.dynsys import NeuGroupNS
from brainpy._src.context import share
+from brainpy._src.dyn.neurons import lif
+from brainpy._src.dynsys import NeuDyn
from brainpy._src.initialize import (ZeroInit,
OneInit,
Initializer,
@@ -33,159 +33,7 @@
]
-class Leaky(NeuGroupNS):
- r"""Leaky Integrator Model.
-
- **Model Descriptions**
-
- This class implements a leaky model, in which its dynamics is
- given by:
-
- .. math::
-
- x(t + \Delta t) = \exp{-1/\tau \Delta t} x(t) + I
-
- Parameters
- ----------
- size: sequence of int, int
- The size of the neuron group.
- tau: float, ArrayType, Initializer, callable
- Membrane time constant.
- method: str
- The numerical integration method.
- name: str
- The group name.
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- name: str = None,
- mode: bm.Mode = None,
- method: str = 'exp_auto',
- ):
- super().__init__(size=size,
- mode=mode,
- keep_size=keep_size,
- name=name)
- assert self.mode.is_parent_of(bm.TrainingMode, bm.NonBatchingMode)
-
- # parameters
- self.tau = parameter(tau, self.varshape, allow_none=False)
-
- # integral
- self.integral = odeint(method=method, f=self.derivative)
-
- # variables
- self.reset_state(self.mode)
-
- def derivative(self, x, t):
- return -x / self.tau
-
- def reset_state(self, batch_size=None):
- self.x = variable_(bm.zeros, self.varshape, batch_size)
-
- def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
- r = self.integral(self.x.value, t, dt)
- if x is not None:
- r += x
- self.x.value = r
- return r
-
-
-class Integrator(NeuGroupNS):
- r"""Integrator Model.
-
- This class implements an integrator model, in which its dynamics is
- given by:
-
- .. math::
-
- \tau \frac{dx}{dt} = - x(t) + I(t)
-
- where :math:`x` is the integrator value, and :math:`\tau` is the time constant.
-
- Parameters
- ----------
- size: sequence of int, int
- The size of the neuron group.
- tau: float, ArrayType, Initializer, callable
- Membrane time constant.
- x_initializer: ArrayType, Initializer, callable
- The initializer of :math:`x`.
- noise: ArrayType, Initializer, callable
- The noise added onto the membrane potential
- method: str
- The numerical integration method.
- name: str
- The group name.
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- x_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- input_var: bool = False,
- name: str = None,
- mode: bm.Mode = None,
- method: str = 'exp_auto',
- ):
- super().__init__(size=size,
- mode=mode,
- keep_size=keep_size,
- name=name)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # parameters
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape)
- self.input_var = input_var
-
- # initializers
- self._x_initializer = is_initializer(x_initializer)
-
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
- # variables
- self.reset_state(self.mode)
-
- def derivative(self, V, t, I_ext):
- return (-V + I_ext) / self.tau
-
- def reset_state(self, batch_size=None):
- self.x = variable_(self._x_initializer, self.varshape, batch_size)
- if self.input_var:
- self.input = variable_(bm.zeros, self.varshape, batch_size)
-
- def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
- if self.input_var:
- if x is not None:
- self.input += x
- x = self.input.value
- else:
- x = 0. if x is None else x
- self.x.value = self.integral(self.x.value, t, I_ext=x, dt=dt)
- return self.x.value
-
- def clear_input(self):
- if self.input_var:
- self.input[:] = 0.
-
-
-class LeakyIntegrator(NeuGroupNS):
+class LeakyIntegrator(NeuDyn):
r"""Leaky Integrator Model.
**Model Descriptions**
@@ -291,7 +139,7 @@ def clear_input(self):
self.input[:] = 0.
-class LIF(NeuGroupNS):
+class LIF(lif.LifRef):
r"""Leaky integrate-and-fire neuron model.
**Model Descriptions**
@@ -348,134 +196,32 @@ class LIF(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
-
- # neuron parameter
- V_rest: Union[float, ArrayType, Initializer, Callable] = 0.,
- V_reset: Union[float, ArrayType, Initializer, Callable] = -5.,
- V_th: Union[float, ArrayType, Initializer, Callable] = 20.,
- R: Union[float, ArrayType, Initializer, Callable] = 1.,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
- V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
-
- # training parameter
- mode: Optional[bm.Mode] = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
-
- # other parameters
- input_var: bool = True,
- ref_var: bool = False,
- method: str = 'exp_auto',
- name: Optional[str] = None,
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super().__init__(size=size,
- name=name,
- keep_size=keep_size,
- mode=mode)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode), self.name)
-
- # parameters
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.R = parameter(R, self.varshape, allow_none=False)
- self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
- self.noise = init_noise(noise, self.varshape)
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
self.input_var = input_var
- self.ref_var = ref_var
-
- # initializers
- self._V_initializer = is_initializer(V_initializer)
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
- def derivative(self, V, t, I):
- return (-V + self.V_rest + self.R * I) / self.tau
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool # the gradient of spike is a float
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
- if self.tau_ref is not None:
- self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
- if self.ref_var:
- self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
-
- # integrate membrane potential
- V = self.integral(self.V.value, t, x, dt)
-
- if self.tau_ref is not None:
- # refractory
- refractory = (t - self.t_last_spike) <= self.tau_ref
- if isinstance(self.mode, bm.TrainingMode):
- refractory = stop_gradient(refractory)
- V = bm.where(refractory, self.V.value, V)
-
- # spike, refractory, spiking time, and membrane potential reset
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- spike_ = spike_no_grad > 0.
- # will be used in other place, like Delta Synapse, so stop its gradient
- if self.ref_var:
- self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value)
- t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
- else:
- spike = V >= self.V_th
- V = bm.where(spike, self.V_reset, V)
- if self.ref_var:
- self.refractory.value = bm.logical_or(refractory, spike)
- t_last_spike = bm.where(spike, t, self.t_last_spike.value)
- self.V.value = V
- self.spike.value = spike
- self.t_last_spike.value = t_last_spike
-
- else:
- # spike, spiking time, and membrane potential reset
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- else:
- spike = V >= self.V_th
- V = bm.where(spike, self.V_reset, V)
- self.V.value = V
- self.spike.value = spike
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class ExpIF(NeuGroupNS):
+class ExpIF(lif.ExpIFRef):
r"""Exponential integrate-and-fire neuron model.
**Model Descriptions**
@@ -574,128 +320,32 @@ class ExpIF(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- V_rest: Union[float, ArrayType, Initializer, Callable] = -65.,
- V_reset: Union[float, ArrayType, Initializer, Callable] = -68.,
- V_th: Union[float, ArrayType, Initializer, Callable] = -30.,
- V_T: Union[float, ArrayType, Initializer, Callable] = -59.9,
- delta_T: Union[float, ArrayType, Initializer, Callable] = 3.48,
- R: Union[float, ArrayType, Initializer, Callable] = 1.,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- tau_ref: Union[float, ArrayType, Initializer, Callable] = None,
- V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
- keep_size: bool = False,
- input_var: bool = True,
- ref_var: bool = False,
- mode: bm.Mode = None,
- method: str = 'exp_auto',
- name: str = None
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialize
- super(ExpIF, self).__init__(size=size,
- name=name,
- mode=mode,
- keep_size=keep_size, )
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # parameters
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.V_T = parameter(V_T, self.varshape, allow_none=False)
- self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
- self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.R = parameter(R, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape)
self.input_var = input_var
- self.ref_var = ref_var
-
- # initializers
- self._V_initializer = is_initializer(V_initializer)
-
- # training setting
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
- if self.tau_ref is not None:
- self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
- if self.ref_var:
- self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
-
- def derivative(self, V, t, I_ext):
- exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
- dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau
- return dvdt
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
-
- V = self.integral(self.V.value, t, x, dt)
-
- if self.tau_ref is not None:
- refractory = (t - self.t_last_spike) <= self.tau_ref
- if isinstance(self.mode, bm.TrainingMode):
- refractory = stop_gradient(refractory)
- V = bm.where(refractory, self.V.value, V)
-
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- spike_ = spike_no_grad > 0.
- self.t_last_spike.value = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
- if self.ref_var:
- # will be used in other place, like Delta Synapse, so stop its gradient
- self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value)
- else:
- spike = self.V_th <= V
- V = bm.where(spike, self.V_reset, V)
- self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)
- if self.ref_var:
- self.refractory.value = bm.logical_or(refractory, spike)
- else:
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- else:
- spike = self.V_th <= V
- V = bm.where(spike, self.V_reset, V)
- self.V.value = V
- self.spike.value = spike
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class AdExIF(NeuGroupNS):
+class AdExIF(lif.AdExIFRef):
r"""Adaptive exponential integrate-and-fire neuron model.
**Model Descriptions**
@@ -771,132 +421,32 @@ class AdExIF(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- V_rest: Union[float, ArrayType, Initializer, Callable] = -65.,
- V_reset: Union[float, ArrayType, Initializer, Callable] = -68.,
- V_th: Union[float, ArrayType, Initializer, Callable] = -30.,
- V_T: Union[float, ArrayType, Initializer, Callable] = -59.9,
- delta_T: Union[float, ArrayType, Initializer, Callable] = 3.48,
- a: Union[float, ArrayType, Initializer, Callable] = 1.,
- b: Union[float, ArrayType, Initializer, Callable] = 1.,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- tau_w: Union[float, ArrayType, Initializer, Callable] = 30.,
- tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
- R: Union[float, ArrayType, Initializer, Callable] = 1.,
- V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
- method: str = 'exp_auto',
- keep_size: bool = False,
- input_var: bool = True,
- mode: bm.Mode = None,
- name: Optional[str] = None
+ self, *args, input_var: bool = True, **kwargs,
):
- super(AdExIF, self).__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode, )
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # parameters
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.V_T = parameter(V_T, self.varshape, allow_none=False)
- self.a = parameter(a, self.varshape, allow_none=False)
- self.b = parameter(b, self.varshape, allow_none=False)
- self.R = parameter(R, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
- self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
- self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape, num_vars=2)
self.input_var = input_var
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
-
- # initializers
- self._V_initializer = is_initializer(V_initializer)
- self._w_initializer = is_initializer(w_initializer)
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # functions
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
- self.w = variable_(self._w_initializer, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.BatchingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
- if self.tau_ref is not None:
- self.refractory = variable_(partial(bm.zeros, dtype=bool), self.varshape, batch_size)
- self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size)
-
- def dV(self, V, t, w, I_ext):
- exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
- dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I_ext) / self.tau
- return dVdt
-
- def dw(self, w, t, V):
- dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
- return dwdt
-
- @property
- def derivative(self):
- return JointEq([self.dV, self.dw])
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
-
- V, w = self.integral(self.V.value, self.w.value, t, x, dt)
-
- if self.tau_ref is not None:
- refractory = (t - self.t_last_spike) <= self.tau_ref
- if isinstance(self.mode, bm.TrainingMode):
- refractory = stop_gradient(refractory)
- V = bm.where(refractory, self.V.value, V)
-
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- w += self.b * spike_no_grad
- spike_ = spike_no_grad > 0.
- if self.tau_ref is not None:
- self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value)
- self.t_last_spike.value = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
- else:
- spike = V >= self.V_th
- self.V.value = bm.where(spike, self.V_reset, V)
- self.w.value = bm.where(spike, w + self.b, w)
- self.spike.value = spike
- if self.tau_ref is not None:
- self.refractory.value = bm.logical_or(refractory, spike)
- self.t_last_spike.value = bm.where(spike, t, self.t_last_spike.value)
-
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class QuaIF(NeuGroupNS):
+class QuaIF(lif.QuaIFRef):
r"""Quadratic Integrate-and-Fire neuron model.
**Model Descriptions**
@@ -964,119 +514,32 @@ class QuaIF(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- V_rest: Union[float, ArrayType, Initializer, Callable] = -65.,
- V_reset: Union[float, ArrayType, Initializer, Callable] = -68.,
- V_th: Union[float, ArrayType, Initializer, Callable] = -30.,
- V_c: Union[float, ArrayType, Initializer, Callable] = -50.0,
- c: Union[float, ArrayType, Initializer, Callable] = .07,
- R: Union[float, ArrayType, Initializer, Callable] = 1.,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- tau_ref: Union[float, ArrayType, Initializer, Callable] = None,
- V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
- keep_size: bool = False,
- input_var: bool = True,
- mode: bm.Mode = None,
- method: str = 'exp_auto',
- name: str = None
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super(QuaIF, self).__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # parameters
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.V_c = parameter(V_c, self.varshape, allow_none=False)
- self.c = parameter(c, self.varshape, allow_none=False)
- self.R = parameter(R, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
- self.noise = init_noise(noise, self.varshape, num_vars=1)
self.input_var = input_var
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
-
- # initializers
- self._V_initializer = is_initializer(V_initializer)
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
- if self.tau_ref is not None:
- self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
- self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
-
- def derivative(self, V, t, I_ext):
- dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
- return dVdt
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
-
- V = self.integral(self.V.value, t, x, dt)
-
- if self.tau_ref is not None:
- refractory = (t - self.t_last_spike) <= self.tau_ref
- if isinstance(self.mode, bm.TrainingMode):
- refractory = stop_gradient(refractory)
- V = bm.where(refractory, self.V.value, V)
-
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- spike_ = spike_no_grad > 0.
- self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value)
- self.t_last_spike.value = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
- else:
- spike = self.V_th <= V
- t_last_spike = bm.where(spike, t, self.t_last_spike.value)
- V = bm.where(spike, self.V_reset, V)
- self.refractory.value = bm.logical_or(refractory, spike)
- self.t_last_spike.value = t_last_spike
- else:
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- else:
- spike = self.V_th <= V
- V = bm.where(spike, self.V_reset, V)
- self.V.value = V
- self.spike.value = spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
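The `clear_input` bodies also change from in-place slice assignment to rebinding the variable's buffer; both zero the accumulated input, but whole-buffer assignment through `.value` presumably keeps the `Variable` metadata (shape, batching) consistent under JAX transformations. In plain NumPy terms, the two forms are:

    import numpy as np

    buf = np.ones(4)
    buf[:] = 0.                    # old style: zero the existing buffer in place
    buf = np.zeros_like(buf)       # new style: rebind to a fresh zero buffer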
-class AdQuaIF(NeuGroupNS):
+class AdQuaIF(lif.AdQuaIFRef):
r"""Adaptive quadratic integrate-and-fire neuron model.
**Model Descriptions**
@@ -1154,110 +617,32 @@ class AdQuaIF(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- V_rest: Union[float, ArrayType, Initializer, Callable] = -65.,
- V_reset: Union[float, ArrayType, Initializer, Callable] = -68.,
- V_th: Union[float, ArrayType, Initializer, Callable] = -30.,
- V_c: Union[float, ArrayType, Initializer, Callable] = -50.0,
- a: Union[float, ArrayType, Initializer, Callable] = 1.,
- b: Union[float, ArrayType, Initializer, Callable] = .1,
- c: Union[float, ArrayType, Initializer, Callable] = .07,
- tau: Union[float, ArrayType, Initializer, Callable] = 10.,
- tau_w: Union[float, ArrayType, Initializer, Callable] = 10.,
- V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
- method: str = 'exp_auto',
- keep_size: bool = False,
- input_var: bool = True,
- mode: bm.Mode = None,
- name: str = None
+ self, *args, input_var: bool = True, **kwargs,
):
- super(AdQuaIF, self).__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode, )
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # parameters
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.V_c = parameter(V_c, self.varshape, allow_none=False)
- self.c = parameter(c, self.varshape, allow_none=False)
- self.a = parameter(a, self.varshape, allow_none=False)
- self.b = parameter(b, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape, num_vars=2)
self.input_var = input_var
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
-
- # initializers
- self._V_initializer = is_initializer(V_initializer)
- self._w_initializer = is_initializer(w_initializer)
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
- self.w = variable_(self._w_initializer, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- def dV(self, V, t, w, I_ext):
- dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
- return dVdt
-
- def dw(self, w, t, V):
- dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w
- return dwdt
-
- @property
- def derivative(self):
- return JointEq([self.dV, self.dw])
-
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
-
- V, w = self.integral(self.V.value, self.w.value, t, x, dt)
-
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += (self.V_reset - V) * spike_no_grad
- w += self.b * spike_no_grad
- else:
- spike = self.V_th <= V
- self.V.value = bm.where(spike, self.V_reset, V)
- self.w.value = bm.where(spike, w + self.b, w)
- self.spike.value = spike
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class GIF(NeuGroupNS):
+class GIF(lif.GifRef):
r"""Generalized Integrate-and-Fire model.
**Model Descriptions**
@@ -1340,305 +725,32 @@ class GIF(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- V_rest: Union[float, ArrayType, Initializer, Callable] = -70.,
- V_reset: Union[float, ArrayType, Initializer, Callable] = -70.,
- V_th_inf: Union[float, ArrayType, Initializer, Callable] = -50.,
- V_th_reset: Union[float, ArrayType, Initializer, Callable] = -60.,
- R: Union[float, ArrayType, Initializer, Callable] = 20.,
- tau: Union[float, ArrayType, Initializer, Callable] = 20.,
- a: Union[float, ArrayType, Initializer, Callable] = 0.,
- b: Union[float, ArrayType, Initializer, Callable] = 0.01,
- k1: Union[float, ArrayType, Initializer, Callable] = 0.2,
- k2: Union[float, ArrayType, Initializer, Callable] = 0.02,
- R1: Union[float, ArrayType, Initializer, Callable] = 0.,
- R2: Union[float, ArrayType, Initializer, Callable] = 1.,
- A1: Union[float, ArrayType, Initializer, Callable] = 0.,
- A2: Union[float, ArrayType, Initializer, Callable] = 0.,
- V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.),
- I1_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- I2_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
- Vth_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- keep_size: bool = False,
- input_var: bool = True,
- name: str = None,
-
- # parameter for training
- mode: bm.Mode = None,
- spike_fun: Callable = bm.surrogate.sigmoid,
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super().__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # params
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
- self.V_th_inf = parameter(V_th_inf, self.varshape, allow_none=False)
- self.V_th_reset = parameter(V_th_reset, self.varshape, allow_none=False)
- self.R = parameter(R, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.a = parameter(a, self.varshape, allow_none=False)
- self.b = parameter(b, self.varshape, allow_none=False)
- self.k1 = parameter(k1, self.varshape, allow_none=False)
- self.k2 = parameter(k2, self.varshape, allow_none=False)
- self.R1 = parameter(R1, self.varshape, allow_none=False)
- self.R2 = parameter(R2, self.varshape, allow_none=False)
- self.A1 = parameter(A1, self.varshape, allow_none=False)
- self.A2 = parameter(A2, self.varshape, allow_none=False)
- self.noise = init_noise(noise, self.varshape, num_vars=4)
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
self.input_var = input_var
-
- # initializers
- self._V_initializer = is_initializer(V_initializer, 'V_initializer')
- self._I1_initializer = is_initializer(I1_initializer, 'I1_initializer')
- self._I2_initializer = is_initializer(I2_initializer, 'I2_initializer')
- self._Vth_initializer = is_initializer(Vth_initializer, 'Vth_initializer')
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
def reset_state(self, batch_size=None):
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
- self.I1 = variable_(self._I1_initializer, self.varshape, batch_size)
- self.I2 = variable_(self._I2_initializer, self.varshape, batch_size)
- self.V_th = variable_(self._Vth_initializer, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if self.mode.is_a(bm.TrainingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
-
- def dI1(self, I1, t):
- return - self.k1 * I1
-
- def dI2(self, I2, t):
- return - self.k2 * I2
-
- def dVth(self, V_th, t, V):
- return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf)
-
- def dV(self, V, t, I1, I2, I_ext):
- return (- (V - self.V_rest) + self.R * (I_ext + I1 + I2)) / self.tau
-
- @property
- def derivative(self):
- return JointEq(self.dI1, self.dI2, self.dVth, self.dV)
-
- def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
- if self.input_var:
- if x is not None:
- self.input += x
- x = self.input.value
- else:
- x = 0. if x is None else x
- I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt)
-
- # spike and resets
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- V += (self.V_reset - V) * spike
- I1 += spike * (self.R1 * I1 + self.A1 - I1)
- I2 += spike * (self.R2 * I2 + self.A2 - I2)
- reset_th = self.spike_fun(self.V_th_reset - V_th) * spike
- V_th += reset_th * (self.V_th_reset - V_th)
- else:
- spike = self.V_th <= V
- V = bm.where(spike, self.V_reset, V)
- I1 = bm.where(spike, self.R1 * I1 + self.A1, I1)
- I2 = bm.where(spike, self.R2 * I2 + self.A2, I2)
- V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th)
- self.spike.value = spike
- self.I1.value = I1
- self.I2.value = I2
- self.V_th.value = V_th
- self.V.value = V
- return spike
-
- def clear_input(self):
- if self.input_var:
- self.input[:] = 0.
-
-
-class ALIFBellec2020(NeuGroupNS):
- r"""Leaky Integrate-and-Fire model with SFA [1]_.
-
- This model is similar to the GLIF2 model in the Technical White Paper
- on generalized LIF (GLIF) models from AllenInstitute [2]_.
-
- Formally, this model is given by:
-
- .. math::
-
- \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\
- \tau_a \dot{a} = -a
-
- Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then
-
- .. math::
-
- V \gets V - V_{\mathrm{th}} \\
- a \gets a + 1
-
- References
- ----------
- .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for
- recurrent networks of spiking neurons."
- Nature communications 11.1 (2020): 1-15.
- .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for
- Brain Science. Allen Cell Types Database, cell feature search.
- Available from: celltypes.brain-map.org/data (2018).
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
-
- # model parameters
- V_rest: Union[float, ArrayType, Initializer, Callable] = -70.,
- V_th: Union[float, ArrayType, Initializer, Callable] = -60.,
- R: Union[float, ArrayType, Initializer, Callable] = 1.,
- beta: Union[float, ArrayType, Initializer, Callable] = 1.6,
- tau: Union[float, ArrayType, Initializer, Callable] = 20.,
- tau_a: Union[float, ArrayType, Initializer, Callable] = 2000.,
- tau_ref: Union[float, ArrayType, Initializer, Callable] = None,
- noise: Union[float, ArrayType, Initializer, Callable] = None,
-
- # initializers
- V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.),
- a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.),
-
- # parameter for training
- spike_fun: Callable = bm.surrogate.relu_grad,
- input_var: bool = True,
-
- # other parameters
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- eprop: bool = False
- ):
- super().__init__(name=name,
- size=size,
- keep_size=keep_size,
- mode=mode)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # parameters
- self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.R = parameter(R, self.varshape, allow_none=False)
- self.beta = parameter(beta, self.varshape, allow_none=False)
- self.tau = parameter(tau, self.varshape, allow_none=False)
- self.tau_a = parameter(tau_a, self.varshape, allow_none=False)
- self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
- self.noise = init_noise(noise, self.varshape, num_vars=2)
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
- self.eprop = eprop
- self.input_var = input_var
-
- # initializers
- self._V_initializer = is_initializer(V_initializer, 'V_initializer')
- self._a_initializer = is_initializer(a_initializer, 'a_initializer')
-
- # variables
- self.reset_state(self.mode)
-
- # integral
- if self.noise is None:
- self.integral = odeint(method=method, f=self.derivative)
- else:
- self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
-
- def da(self, a, t):
- return -a / self.tau_a
-
- def dV(self, V, t, I_ext):
- return (- (V - self.V_rest) + self.R * I_ext) / self.tau
-
- @property
- def derivative(self):
- return JointEq([self.dV, self.da])
-
- def reset_state(self, batch_size=None):
- self.a = variable_(self._a_initializer, self.varshape, batch_size)
- self.V = variable_(self._V_initializer, self.varshape, batch_size)
- if self.input_var:
- self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
- if self.tau_ref is not None:
- self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
- self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
-
- def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
+ def update(self, x=None):
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
- V, a = self.integral(self.V, self.a, t, x, dt)
-
- if self.tau_ref is not None:
- # refractory
- refractory = (t - self.t_last_spike) <= self.tau_ref
- if isinstance(self.mode, bm.TrainingMode):
- refractory = stop_gradient(refractory)
- V = bm.where(refractory, self.V.value, V)
- # spike and reset
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
- V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
- # will be used in other place, like Delta Synapse, so stop its gradient
- spike_ = spike > 0.
- refractory = stop_gradient(bm.logical_or(refractory, spike_))
- t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
- else:
- spike = V >= (self.V_th + self.beta * self.a)
- refractory = bm.logical_or(refractory, spike)
- t_last_spike = bm.where(spike, t, self.t_last_spike.value)
- V -= self.V_th * spike
- self.refractory.value = refractory
- self.t_last_spike.value = t_last_spike
-
- else:
- # spike and reset
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
- V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
- else:
- spike = V >= (self.V_th + self.beta * self.a)
- V -= self.V_th * spike
- self.spike.value = spike
- self.V.value = V
- self.a.value = a + spike
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
-class Izhikevich(NeuGroupNS):
+class Izhikevich(lif.IzhikevichRef):
r"""The Izhikevich neuron model.
**Model Descriptions**
@@ -1707,137 +819,32 @@ class Izhikevich(NeuGroupNS):
"""
def __init__(
- self,
- size: Shape,
- a: Union[float, ArrayType, Initializer, Callable] = 0.02,
- b: Union[float, ArrayType, Initializer, Callable] = 0.20,
- c: Union[float, ArrayType, Initializer, Callable] = -65.,
- d: Union[float, ArrayType, Initializer, Callable] = 8.,
- V_th: Union[float, ArrayType, Initializer, Callable] = 30.,
- tau_ref: Union[float, ArrayType, Initializer, Callable] = None,
- V_initializer: Union[Initializer, Callable, ArrayType] = None,
- u_initializer: Union[Initializer, Callable, ArrayType] = None,
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- mode: bm.Mode = None,
- spike_fun: Callable = bm.surrogate.inv_square_grad,
- keep_size: bool = False,
- input_var: bool = True,
- ref_var: bool = False,
- name: str = None
+ self, *args, input_var: bool = True, **kwargs,
):
- # initialization
- super().__init__(size=size,
- keep_size=keep_size,
- name=name,
- mode=mode)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
-
- # params
- self.a = parameter(a, self.varshape, allow_none=False)
- self.b = parameter(b, self.varshape, allow_none=False)
- self.c = parameter(c, self.varshape, allow_none=False)
- self.d = parameter(d, self.varshape, allow_none=False)
- self.V_th = parameter(V_th, self.varshape, allow_none=False)
- self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
- self.noise = init_noise(noise, self.varshape, num_vars=2)
- self.spike_fun = is_callable(spike_fun, 'spike_fun')
self.input_var = input_var
- self.ref_var = ref_var
-
- # initializers
- self._V_initializer = is_initializer(V_initializer, allow_none=True)
- self._u_initializer = is_initializer(u_initializer, allow_none=True)
-
- # variables
+ super().__init__(*args, **kwargs, init_var=False)
self.reset_state(self.mode)
- # functions
- if self.noise is None:
- self.integral = odeint(method=method, f=JointEq([self.dV, self.du]))
- else:
- self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise)
-
def reset_state(self, batch_size=None):
- v_init = OneInit(-70.) if self._V_initializer is None else self._V_initializer
- self.V = variable_(v_init, self.varshape, batch_size)
- u_init = OneInit(self.b * self.V) if self._u_initializer is None else self._u_initializer
- self.u = variable_(u_init, self.varshape, batch_size)
+ super().reset_state(batch_size)
if self.input_var:
self.input = variable_(bm.zeros, self.varshape, batch_size)
- sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
- self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
- if self.tau_ref is not None:
- self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
- if self.ref_var:
- self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
-
- def dV(self, V, t, u, I_ext):
- dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext
- return dVdt
-
- def du(self, u, t, V):
- dudt = self.a * (self.b * V - u)
- return dudt
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
if self.input_var:
if x is not None:
self.input += x
x = self.input.value
else:
x = 0. if x is None else x
- V, u = self.integral(self.V.value, self.u.value, t, x, dt)
-
- if self.tau_ref is not None:
- refractory = bm.as_jax((t - self.t_last_spike) <= self.tau_ref)
- refractory = stop_gradient(refractory)
- V = bm.where(refractory, self.V.value, V)
-
- # spike, refractory, and reset membrane potential
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += spike_no_grad * (self.c - self.V_th)
- u += spike_no_grad * self.d
- t_last_spike = stop_gradient(bm.where(spike_no_grad, t, self.t_last_spike.value))
- if self.ref_var:
- self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_no_grad > 0.))
- else:
- spike = self.V_th <= V
- V = bm.where(spike, self.c, V)
- u = bm.where(spike, u + self.d, u)
- t_last_spike = bm.where(spike, t, self.t_last_spike.value)
- if self.ref_var:
- self.refractory.value = bm.logical_or(refractory, spike)
- self.t_last_spike.value = t_last_spike
-
- else:
- # spike, refractory, and reset membrane potential
- if isinstance(self.mode, bm.TrainingMode):
- spike = self.spike_fun(V - self.V_th)
- spike_no_grad = stop_gradient(spike)
- V += spike_no_grad * (self.c - self.V_th)
- u += spike_no_grad * self.d
- else:
- spike = self.V_th <= V
- V = bm.where(spike, self.c, V)
- u = bm.where(spike, u + self.d, u)
-
- # finally
- self.V.value = V
- self.u.value = u
- self.spike.value = spike
- return spike
+ return super().update(x)
def clear_input(self):
if self.input_var:
- self.input[:] = 0.
+ self.input.value = bm.zeros_like(self.input)
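The removed `dV`/`du` methods implemented the classic Izhikevich equations; `lif.IzhikevichRef` is assumed to keep the same dynamics. For reference, a standalone Euler step using the defaults from the removed constructor (plain NumPy, with no refractory or surrogate-gradient handling):

    import numpy as np

    def izhikevich_step(V, u, I, dt=0.1, a=0.02, b=0.2, c=-65., d=8., V_th=30.):
        dV = 0.04 * V * V + 5.0 * V + 140.0 - u + I
        du = a * (b * V - u)
        V, u = V + dt * dV, u + dt * du
        spike = V >= V_th
        V = np.where(spike, c, V)        # reset membrane potential
        u = np.where(spike, u + d, u)    # jump of the recovery variable
        return V, u, spike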
-class HindmarshRose(NeuGroupNS):
+class HindmarshRose(NeuDyn):
r"""Hindmarsh-Rose neuron model.
**Model Descriptions**
@@ -2043,7 +1050,7 @@ def clear_input(self):
self.input[:] = 0.
-class FHN(NeuGroupNS):
+class FHN(NeuDyn):
r"""FitzHugh-Nagumo neuron model.
**Model Descriptions**
@@ -2211,7 +1218,171 @@ def clear_input(self):
self.input[:] = 0.
-class LIF_SFA_Bellec2020(NeuGroupNS):
+class ALIFBellec2020(NeuDyn):
+ r"""Leaky Integrate-and-Fire model with SFA [1]_.
+
+ This model is similar to the GLIF2 model in the Technical White Paper
+ on generalized LIF (GLIF) models from Allen Institute [2]_.
+
+ Formally, this model is given by:
+
+ .. math::
+
+ \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\
+ \tau_a \dot{a} = -a
+
+ Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then
+
+ .. math::
+
+ V \gets V - V_{\mathrm{th}} \\
+ a \gets a + 1
+
+
+ References
+ ----------
+ .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for
+ recurrent networks of spiking neurons."
+ Nature communications 11.1 (2020): 1-15.
+ .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for
+ Brain Science. Allen Cell Types Database, cell feature search.
+ Available from: celltypes.brain-map.org/data (2018).
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+
+ # model parameters
+ V_rest: Union[float, ArrayType, Initializer, Callable] = -70.,
+ V_th: Union[float, ArrayType, Initializer, Callable] = -60.,
+ R: Union[float, ArrayType, Initializer, Callable] = 1.,
+ beta: Union[float, ArrayType, Initializer, Callable] = 1.6,
+ tau: Union[float, ArrayType, Initializer, Callable] = 20.,
+ tau_a: Union[float, ArrayType, Initializer, Callable] = 2000.,
+ tau_ref: Union[float, ArrayType, Initializer, Callable] = None,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+
+ # initializers
+ V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.),
+ a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.),
+
+ # parameter for training
+ spike_fun: Callable = bm.surrogate.relu_grad,
+ input_var: bool = True,
+
+ # other parameters
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ eprop: bool = False
+ ):
+ super().__init__(name=name,
+ size=size,
+ keep_size=keep_size,
+ mode=mode)
+ is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
+
+ # parameters
+ self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
+ self.V_th = parameter(V_th, self.varshape, allow_none=False)
+ self.R = parameter(R, self.varshape, allow_none=False)
+ self.beta = parameter(beta, self.varshape, allow_none=False)
+ self.tau = parameter(tau, self.varshape, allow_none=False)
+ self.tau_a = parameter(tau_a, self.varshape, allow_none=False)
+ self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
+ self.noise = init_noise(noise, self.varshape, num_vars=2)
+ self.spike_fun = is_callable(spike_fun, 'spike_fun')
+ self.eprop = eprop
+ self.input_var = input_var
+
+ # initializers
+ self._V_initializer = is_initializer(V_initializer, 'V_initializer')
+ self._a_initializer = is_initializer(a_initializer, 'a_initializer')
+
+ # variables
+ self.reset_state(self.mode)
+
+ # integral
+ if self.noise is None:
+ self.integral = odeint(method=method, f=self.derivative)
+ else:
+ self.integral = sdeint(method=method, f=self.derivative, g=self.noise)
+
+ def da(self, a, t):
+ return -a / self.tau_a
+
+ def dV(self, V, t, I_ext):
+ return (- (V - self.V_rest) + self.R * I_ext) / self.tau
+
+ @property
+ def derivative(self):
+ return JointEq([self.dV, self.da])
+
+ def reset_state(self, batch_size=None):
+ self.a = variable_(self._a_initializer, self.varshape, batch_size)
+ self.V = variable_(self._V_initializer, self.varshape, batch_size)
+ if self.input_var:
+ self.input = variable_(bm.zeros, self.varshape, batch_size)
+ sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
+ self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
+ if self.tau_ref is not None:
+ self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
+ self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
+
+ def update(self, x=None):
+ t = share.load('t')
+ dt = share.load('dt')
+ if self.input_var:
+ if x is not None:
+ self.input += x
+ x = self.input.value
+ else:
+ x = 0. if x is None else x
+ V, a = self.integral(self.V, self.a, t, x, dt)
+
+ if self.tau_ref is not None:
+ # refractory
+ refractory = (t - self.t_last_spike) <= self.tau_ref
+ if isinstance(self.mode, bm.TrainingMode):
+ refractory = stop_gradient(refractory)
+ V = bm.where(refractory, self.V.value, V)
+ # spike and reset
+ if isinstance(self.mode, bm.TrainingMode):
+ spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
+ V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
+ # will be used in other places, like a Delta synapse, so stop its gradient
+ spike_ = spike > 0.
+ refractory = stop_gradient(bm.logical_or(refractory, spike_))
+ t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value))
+ else:
+ spike = V >= (self.V_th + self.beta * self.a)
+ refractory = bm.logical_or(refractory, spike)
+ t_last_spike = bm.where(spike, t, self.t_last_spike.value)
+ V -= self.V_th * spike
+ self.refractory.value = refractory
+ self.t_last_spike.value = t_last_spike
+
+ else:
+ # spike and reset
+ if isinstance(self.mode, bm.TrainingMode):
+ spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th)
+ V -= self.V_th * (stop_gradient(spike) if self.eprop else spike)
+ else:
+ spike = V >= (self.V_th + self.beta * self.a)
+ V -= self.V_th * spike
+ self.spike.value = spike
+ self.V.value = V
+ self.a.value = a + spike
+ return spike
+
+ def clear_input(self):
+ if self.input_var:
+ self.input[:] = 0.
+
+
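For readers of the docstring above, a standalone Euler sketch of the adaptive-threshold dynamics (plain NumPy, ignoring the refractory period, noise and the e-prop/surrogate-gradient branches handled in `update`):

    import numpy as np

    def alif_step(V, a, I, dt=0.1, V_rest=-70., V_th=-60., R=1., beta=1.6,
                  tau=20., tau_a=2000.):
        V = V + dt * (-(V - V_rest) + R * I) / tau   # leaky integration
        a = a + dt * (-a) / tau_a                    # adaptation decays
        spike = V >= (V_th + beta * a)               # adaptive threshold
        V = np.where(spike, V - V_th, V)             # soft reset: V <- V - V_th
        a = a + spike                                # a <- a + 1 on spike
        return V, a, spike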
+class LIF_SFA_Bellec2020(NeuDyn):
r"""Leaky Integrate-and-Fire model with SFA [1]_.
This model is similar to the GLIF2 model in the Technical White Paper
diff --git a/brainpy/_src/neurons/tests/test_biological_neurons.py b/brainpy/_src/dynold/neurons/tests/test_biological_neurons.py
similarity index 75%
rename from brainpy/_src/neurons/tests/test_biological_neurons.py
rename to brainpy/_src/dynold/neurons/tests/test_biological_neurons.py
index 94c22a514..907ebfe0a 100644
--- a/brainpy/_src/neurons/tests/test_biological_neurons.py
+++ b/brainpy/_src/dynold/neurons/tests/test_biological_neurons.py
@@ -4,10 +4,12 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.neurons import biological_models
+from brainpy._src.dynold.neurons import biological_models
+
class Test_Biological(parameterized.TestCase):
def test_HH(self):
+ bm.random.seed()
model = biological_models.HH(size=1)
runner = bp.DSRunner(model,
monitors=['V', 'm', 'n', 'h', 'spike'],
@@ -18,8 +20,10 @@ def test_HH(self):
self.assertTupleEqual(runner.mon['n'].shape, (100, 1))
self.assertTupleEqual(runner.mon['h'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_HH_with_noise(self):
+ bm.random.seed()
model = biological_models.HH(size=1, noise=0.1)
runner = bp.DSRunner(model,
monitors=['V', 'm', 'n', 'h', 'spike'],
@@ -30,8 +34,10 @@ def test_HH_with_noise(self):
self.assertTupleEqual(runner.mon['n'].shape, (100, 1))
self.assertTupleEqual(runner.mon['h'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_HH_batching_mode(self):
+ bm.random.seed()
model = biological_models.HH(size=10, mode=bm.batching_mode)
runner = bp.DSRunner(model,
monitors=['V', 'm', 'n', 'h', 'spike'],
@@ -42,93 +48,112 @@ def test_HH_batching_mode(self):
self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
+ bm.clear_buffer_memory()
def test_MorrisLecar(self):
+ bm.random.seed()
model = biological_models.MorrisLecar(size=1)
runner = bp.DSRunner(model,
- monitors=['V', 'W', 'spike'],
- progress_bar=False)
+ monitors=['V', 'W', 'spike'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['W'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_MorrisLecar_with_noise(self):
+ bm.random.seed()
model = biological_models.MorrisLecar(size=1, noise=0.1)
runner = bp.DSRunner(model,
- monitors=['V', 'W', 'spike'],
- progress_bar=False)
+ monitors=['V', 'W', 'spike'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['W'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_MorrisLecar_batching_mode(self):
+ bm.random.seed()
model = biological_models.MorrisLecar(size=10, mode=bm.batching_mode)
runner = bp.DSRunner(model,
- monitors=['V', 'W', 'spike'],
- progress_bar=False)
+ monitors=['V', 'W', 'spike'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
+ bm.clear_buffer_memory()
def test_PinskyRinzelModel(self):
+ bm.random.seed()
model = biological_models.PinskyRinzelModel(size=1)
runner = bp.DSRunner(model,
- monitors=['Vs', 'Vd'],
- progress_bar=False)
+ monitors=['Vs', 'Vd'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['Vs'].shape, (100, 1))
self.assertTupleEqual(runner.mon['Vd'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_PinskyRinzelModel_with_noise(self):
+ bm.random.seed()
model = biological_models.PinskyRinzelModel(size=1, noise=0.1)
runner = bp.DSRunner(model,
- monitors=['Vs', 'Vd'],
- progress_bar=False)
+ monitors=['Vs', 'Vd'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['Vs'].shape, (100, 1))
self.assertTupleEqual(runner.mon['Vd'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_PinskyRinzelModel_batching_mode(self):
+ bm.random.seed()
model = biological_models.PinskyRinzelModel(size=10, mode=bm.batching_mode)
runner = bp.DSRunner(model,
- monitors=['Vs', 'Vd'],
- progress_bar=False)
+ monitors=['Vs', 'Vd'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['Vs'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['Vd'].shape, (1, 100, 10))
+ bm.clear_buffer_memory()
def test_WangBuzsakiModel(self):
+ bm.random.seed()
model = biological_models.WangBuzsakiModel(size=1)
runner = bp.DSRunner(model,
- monitors=['V', 'n', 'h', 'spike'],
- progress_bar=False)
+ monitors=['V', 'n', 'h', 'spike'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['n'].shape, (100, 1))
self.assertTupleEqual(runner.mon['h'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_WangBuzsakiModel_with_noise(self):
+ bm.random.seed()
model = biological_models.WangBuzsakiModel(size=1, noise=0.1)
runner = bp.DSRunner(model,
- monitors=['V', 'n', 'h', 'spike'],
- progress_bar=False)
+ monitors=['V', 'n', 'h', 'spike'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['n'].shape, (100, 1))
self.assertTupleEqual(runner.mon['h'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
def test_WangBuzsakiModel_batching_mode(self):
+ bm.random.seed()
model = biological_models.WangBuzsakiModel(size=10, mode=bm.batching_mode)
runner = bp.DSRunner(model,
- monitors=['V', 'n', 'h', 'spike'],
- progress_bar=False)
+ monitors=['V', 'n', 'h', 'spike'],
+ progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10))
self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10))
- self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
\ No newline at end of file
+ self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10))
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/neurons/tests/test_fractional_neurons.py b/brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py
similarity index 80%
rename from brainpy/_src/neurons/tests/test_fractional_neurons.py
rename to brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py
index be7a9f929..9752eaaf1 100644
--- a/brainpy/_src/neurons/tests/test_fractional_neurons.py
+++ b/brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py
@@ -3,11 +3,12 @@
import brainpy as bp
from absl.testing import parameterized
-from brainpy._src.neurons import fractional_models
+from brainpy._src.dynold.neurons import fractional_models
class Test_Fractional(parameterized.TestCase):
def test_FractionalFHR(self):
+ bp.math.random.seed()
model = fractional_models.FractionalFHR(size=1, alpha=0.5)
runner = bp.DSRunner(model,
monitors=['V', 'w', 'y', 'spike'],
@@ -17,8 +18,10 @@ def test_FractionalFHR(self):
self.assertTupleEqual(runner.mon['w'].shape, (100, 1))
self.assertTupleEqual(runner.mon['y'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bp.math.clear_buffer_memory()
def test_FractionalIzhikevich(self):
+ bp.math.random.seed()
model = fractional_models.FractionalIzhikevich(size=1, alpha=0.5, num_memory=1000)
runner = bp.DSRunner(model,
monitors=['V', 'u', 'spike'],
@@ -26,4 +29,5 @@ def test_FractionalIzhikevich(self):
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['u'].shape, (100, 1))
- self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
\ No newline at end of file
+ self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bp.math.clear_buffer_memory()
diff --git a/brainpy/_src/neurons/tests/test_reduced_neurons.py b/brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py
similarity index 92%
rename from brainpy/_src/neurons/tests/test_reduced_neurons.py
rename to brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py
index 279b95d49..f4f411759 100644
--- a/brainpy/_src/neurons/tests/test_reduced_neurons.py
+++ b/brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py
@@ -4,7 +4,8 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.neurons import reduced_models
+from brainpy._src.dynold.neurons import reduced_models
+
class Test_Reduced(parameterized.TestCase):
@parameterized.named_parameters(
@@ -12,6 +13,7 @@ class Test_Reduced(parameterized.TestCase):
for name in reduced_models.__all__
)
def test_run_shape(self, neuron):
+ bm.random.seed()
model = getattr(reduced_models, neuron)(size=1)
if neuron == 'LeakyIntegrator':
runner = bp.DSRunner(model,
@@ -26,12 +28,14 @@ def test_run_shape(self, neuron):
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
@parameterized.named_parameters(
{'testcase_name': f'noise_of_{name}', 'neuron': name}
for name in reduced_models.__all__
)
def test_noise_shape(self, neuron):
+ bm.random.seed()
model = getattr(reduced_models, neuron)(size=1, noise=0.1)
if neuron == 'LeakyIntegrator':
runner = bp.DSRunner(model,
@@ -46,12 +50,14 @@ def test_noise_shape(self, neuron):
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
self.assertTupleEqual(runner.mon['spike'].shape, (100, 1))
+ bm.clear_buffer_memory()
@parameterized.named_parameters(
{'testcase_name': f'noise_of_{name}', 'neuron': name}
for name in reduced_models.__all__
)
def test_training_shape(self, neuron):
+ bm.random.seed()
if neuron == 'FHN':
model = getattr(reduced_models, neuron)(size=10)
runner = bp.DSRunner(model,
@@ -66,3 +72,4 @@ def test_training_shape(self, neuron):
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10))
+ bm.clear_buffer_memory()
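The test updates repeat the same two lines in every case: seed the global RNG at the top and release JAX buffers at the end. A hypothetical consolidation using `setUp`/`tearDown` would look like the sketch below; the patch keeps the explicit per-test calls instead:

    import brainpy.math as bm
    from absl.testing import parameterized

    class SeededTestCase(parameterized.TestCase):
        def setUp(self):
            super().setUp()
            bm.random.seed()            # fresh RNG state before each test

        def tearDown(self):
            bm.clear_buffer_memory()    # release buffers between tests
            super().tearDown()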
diff --git a/brainpy/_src/synapses/__init__.py b/brainpy/_src/dynold/synapses/__init__.py
similarity index 53%
rename from brainpy/_src/synapses/__init__.py
rename to brainpy/_src/dynold/synapses/__init__.py
index ca2960417..233535ff5 100644
--- a/brainpy/_src/synapses/__init__.py
+++ b/brainpy/_src/dynold/synapses/__init__.py
@@ -1,10 +1,8 @@
# -*- coding: utf-8 -*-
+from .base import *
from .abstract_models import *
from .biological_models import *
from .learning_rules import *
-from .gap_junction import *
-from .delay_couplings import *
+from .compat import *
-# compatible interface
-from . import compat
diff --git a/brainpy/_src/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py
similarity index 65%
rename from brainpy/_src/synapses/abstract_models.py
rename to brainpy/_src/dynold/synapses/abstract_models.py
index 4f82392db..8366bbe9c 100644
--- a/brainpy/_src/synapses/abstract_models.py
+++ b/brainpy/_src/dynold/synapses/abstract_models.py
@@ -2,17 +2,17 @@
from typing import Union, Dict, Callable, Optional
-from jax import vmap
-from jax.lax import stop_gradient
+import jax
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector, All2All, One2One
-from brainpy._src.synouts import CUBA, MgBlock
-from brainpy._src.dynsys import NeuGroup, SynOut, SynSTP, TwoEndConn, SynConn
-from brainpy._src.initialize import Initializer, variable_
-from brainpy._src.integrators import odeint, JointEq
-from brainpy.check import is_integer, is_float, is_subclass
+from brainpy._src.dyn import synapses
+from brainpy._src.dynold.synouts import MgBlock, CUBA
+from brainpy._src.dynsys import NeuDyn
+from brainpy._src.initialize import Initializer
+from brainpy._src.mixin import AlignPost
from brainpy.types import ArrayType
+from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre, _TwoEndConnAlignPost, _DelayedSyn, _init_stp
__all__ = [
'Delta',
@@ -20,7 +20,6 @@
'DualExponential',
'Alpha',
'NMDA',
- 'PoissonInput',
]
@@ -67,9 +66,9 @@ class Delta(TwoEndConn):
Parameters
----------
- pre: NeuGroup
+ pre: NeuDyn
The pre-synaptic neuron group.
- post: NeuGroup
+ post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
@@ -86,17 +85,15 @@ class Delta(TwoEndConn):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = CUBA(target_var='V'),
- stp: Optional[SynSTP] = None,
+ output: _SynOut = CUBA(target_var='V'),
+ stp: Optional[_SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
delay_step: Union[float, ArrayType, Initializer, Callable] = None,
post_ref_key: str = None,
-
- # other parameters
name: str = None,
mode: bm.Mode = None,
stop_spike_gradient: bool = False,
@@ -127,17 +124,17 @@ def reset_state(self, batch_size=None):
if self.stp is not None:
self.stp.reset_state(batch_size)
- def update(self, tdi, pre_spike=None):
+ def update(self, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
- pre_spike = stop_gradient(pre_spike)
+ pre_spike = jax.lax.stop_gradient(pre_spike)
# update sub-components
- self.output.update(tdi)
- if self.stp is not None: self.stp.update(tdi, pre_spike)
+ if self.stp is not None:
+ self.stp.update(pre_spike)
# synaptic values onto the post
if isinstance(self.conn, All2All):
@@ -152,18 +149,20 @@ def update(self, tdi, pre_spike=None):
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
- f = lambda s: bm.event.csrmv(
- self.g_max, self.conn_mask[0], self.conn_mask[1], s,
- shape=(self.pre.num, self.post.num), transpose=True
- )
- if isinstance(self.mode, bm.BatchingMode): f = vmap(f)
- post_vs = f(pre_spike)
- # if not isinstance(self.stp, _NullSynSTP):
- # raise NotImplementedError()
- # stp_value = self.stp(1.)
- # f2 = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
- # if self.trainable: f2 = vmap(f2)
- # post_vs *= f2(stp_value)
+ if self.stp is not None:
+ syn_value = self.stp(pre_spike)
+ f = lambda s: bm.sparse.csrmv(
+ self.g_max, self.conn_mask[0], self.conn_mask[1], s,
+ shape=(self.pre.num, self.post.num), transpose=True
+ )
+ else:
+ syn_value = pre_spike
+ f = lambda s: bm.event.csrmv(
+ self.g_max, self.conn_mask[0], self.conn_mask[1], s,
+ shape=(self.pre.num, self.post.num), transpose=True
+ )
+ if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
+ post_vs = f(syn_value)
else:
syn_value = bm.asarray(pre_spike, dtype=bm.float_)
if self.stp is not None:
@@ -176,7 +175,7 @@ def update(self, tdi, pre_spike=None):
return self.output(post_vs)
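In the sparse branch above, the spike vector stays boolean when no STP is attached, so the event-driven kernel `bm.event.csrmv` can skip silent presynaptic neurons; with STP the spikes are first rescaled to floats, which is why the general `bm.sparse.csrmv` kernel is used instead. Both compute the same transposed CSR matrix-vector product; a dense-equivalent reference (assuming a scalar `g_max`, purely for illustration):

    import numpy as np

    def csr_matvec_transposed(g_max, indices, indptr, s, n_post):
        # Accumulate g_max * s[pre] onto every postsynaptic target listed
        # in row `pre` of the CSR structure (the transpose=True case).
        out = np.zeros(n_post)
        for pre in range(len(indptr) - 1):
            if s[pre]:                                   # event-driven shortcut
                cols = indices[indptr[pre]:indptr[pre + 1]]
                out[cols] += g_max * float(s[pre])
        return out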
-class Exponential(TwoEndConn):
+class Exponential(_TwoEndConnAlignPost, AlignPost):
r"""Exponential decay synapse model.
**Model Descriptions**
@@ -242,9 +241,9 @@ class Exponential(TwoEndConn):
Parameters
----------
- pre: NeuGroup
+ pre: NeuDyn
The pre-synaptic neuron group.
- post: NeuGroup
+ post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
@@ -273,29 +272,20 @@ class Exponential(TwoEndConn):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: Optional[SynOut] = CUBA(),
- stp: Optional[SynSTP] = None,
+ output: Optional[_SynOut] = CUBA(),
+ stp: Optional[_SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
tau: Union[float, ArrayType] = 8.0,
method: str = 'exp_auto',
-
- # other parameters
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
- super(Exponential, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=output,
- stp=stp,
- name=name,
- mode=mode)
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
@@ -303,67 +293,50 @@ def __init__(
if bm.size(self.tau) != 1:
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
- # connections and weights
- self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
-
- # variables
- self.g = variable_(bm.zeros, self.post.num, self.mode)
- self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
-
- # function
- self.integral = odeint(lambda g, t: -g / self.tau, method=method)
-
- def reset_state(self, batch_size=None):
- self.g.value = variable_(bm.zeros, self.post.num, batch_size)
- self.output.reset_state(batch_size)
- if self.stp is not None: self.stp.reset_state(batch_size)
-
- def update(self, tdi, pre_spike=None):
- t, dt = tdi['t'], tdi['dt']
-
- # delays
- if pre_spike is None:
- pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
- pre_spike = bm.as_jax(pre_spike)
- if self.stop_spike_gradient:
- pre_spike = stop_gradient(pre_spike)
-
- # update sub-components
- self.output.update(tdi)
- if self.stp is not None: self.stp.update(tdi, pre_spike)
-
- # post values
- if isinstance(self.conn, All2All):
- syn_value = bm.asarray(pre_spike, dtype=bm.float_)
- if self.stp is not None: syn_value = self.stp(syn_value)
- post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
- elif isinstance(self.conn, One2One):
- syn_value = bm.asarray(pre_spike, dtype=bm.float_)
- if self.stp is not None: syn_value = self.stp(syn_value)
- post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
- else:
- if self.comp_method == 'sparse':
- f = lambda s: bm.event.csrmv(
- self.g_max, self.conn_mask[0], self.conn_mask[1], s,
- shape=(self.pre.num, self.post.num),
- transpose=True
- )
- if isinstance(self.mode, bm.BatchingMode): f = vmap(f)
- post_vs = f(pre_spike)
- # if not isinstance(self.stp, _NullSynSTP):
- # raise NotImplementedError()
- else:
- syn_value = bm.asarray(pre_spike, dtype=bm.float_)
- if self.stp is not None: syn_value = self.stp(syn_value)
- post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
- # updates
- self.g.value = self.integral(self.g.value, t, dt) + post_vs
-
- # output
- return self.output(self.g)
-
-
-class DualExponential(TwoEndConn):
+ syn = synapses.Expon.desc(pre.size,
+ pre.keep_size,
+ mode=mode,
+ tau=tau,
+ method=method)
+
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ output=output,
+ stp=stp,
+ comp_method=comp_method,
+ g_max=g_max,
+ delay_step=delay_step,
+ name=name,
+ mode=mode)
+
+ # copy the references
+ syn = self.post.before_updates[self.proj._post_repr].syn
+ self.g = syn.g
+
+ def update(self, pre_spike=None):
+ return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
+
+ def add_current(self, input):
+ self.g += input
+
+
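`Exponential` now delegates its state to a `synapses.Expon` instance registered on the postsynaptic side and merely aliases that conductance as `self.g`; `add_current` lets other components deposit current directly into it, which is what the `AlignPost` mixin appears to formalize. A toy, BrainPy-free sketch of the idea:

    import numpy as np

    class SharedConductance:
        """Toy stand-in: one decaying conductance buffer shared by all projections."""

        def __init__(self, n_post, tau=8.0):
            self.g = np.zeros(n_post)
            self.tau = tau

        def add_current(self, inp):
            self.g += inp                          # every projection deposits here

        def step(self, dt=0.1):
            self.g += dt * (-self.g / self.tau)    # exponential decay between events
            return self.g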
+class _DelayedDualExp(_DelayedSyn):
+ not_desc_params = ('master', 'stp', 'mode')
+
+ def __init__(self, size, keep_size, mode, tau_decay, tau_rise, method, master, stp=None):
+ syn = synapses.DualExpon(size,
+ keep_size,
+ mode=mode,
+ tau_decay=tau_decay,
+ tau_rise=tau_rise,
+ method=method)
+ stp = _init_stp(stp, master)
+ super().__init__(syn, stp)
+
+
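`_DelayedDualExp` (and `_DelayedNMDA` further down) are built through `.desc(...)`, with `not_desc_params` excluding arguments such as `master` and `mode` that are only known later. This looks like a deferred-construction descriptor; a hypothetical illustration of the pattern (not the BrainPy implementation):

    import functools

    class Deferrable:
        @classmethod
        def desc(cls, *args, **kwargs):
            # capture construction arguments now, build the object later
            return functools.partial(cls, *args, **kwargs)

    class DelayedKernel(Deferrable):
        def __init__(self, size, tau_decay, tau_rise, master=None):
            self.size = size
            self.tau_decay = tau_decay
            self.tau_rise = tau_rise
            self.master = master

    # usage: the caller stores the description; the owner finishes construction
    maker = DelayedKernel.desc(100, tau_decay=10., tau_rise=1.)
    kernel = maker(master=object())      # late-bound argument supplied afterwards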
+class DualExponential(_TwoEndConnAlignPre):
r"""Dual exponential synapse model.
**Model Descriptions**
@@ -425,9 +398,9 @@ class DualExponential(TwoEndConn):
Parameters
----------
- pre: NeuGroup
+ pre: NeuDyn
The pre-synaptic neuron group.
- post: NeuGroup
+ post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
@@ -460,11 +433,11 @@ class DualExponential(TwoEndConn):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- stp: Optional[SynSTP] = None,
- output: SynOut = CUBA(),
+ stp: Optional[_SynSTP] = None,
+ output: _SynOut = None, # CUBA(),
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
tau_decay: Union[float, ArrayType] = 10.0,
@@ -477,16 +450,8 @@ def __init__(
mode: bm.Mode = None,
stop_spike_gradient: bool = False,
):
- super(DualExponential, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=output,
- stp=stp,
- name=name,
- mode=mode)
+
# parameters
- # self.check_pre_attrs('spike')
- self.check_post_attrs('input')
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
self.tau_rise = tau_rise
@@ -498,68 +463,35 @@ def __init__(
raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. '
f'But we got {self.tau_decay}')
- # connections
- self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
-
- # variables
- self.h = variable_(bm.zeros, self.pre.num, self.mode)
- self.g = variable_(bm.zeros, self.pre.num, self.mode)
- self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
-
- # integral
- self.integral = odeint(method=method, f=JointEq([self.dg, self.dh]))
-
- def reset_state(self, batch_size=None):
- self.h.value = variable_(bm.zeros, self.pre.num, batch_size)
- self.g.value = variable_(bm.zeros, self.pre.num, batch_size)
- self.output.reset_state(batch_size)
- if self.stp is not None: self.stp.reset_state(batch_size)
-
- def dh(self, h, t):
- return -h / self.tau_rise
-
- def dg(self, g, t, h):
- return -g / self.tau_decay + h
-
- def update(self, tdi, pre_spike=None):
- t, dt = tdi['t'], tdi['dt']
-
- # pre-synaptic spikes
- if pre_spike is None:
- pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
- pre_spike = bm.as_jax(pre_spike)
- if self.stop_spike_gradient:
- pre_spike = stop_gradient(pre_spike)
-
- # update sub-components
- self.output.update(tdi)
- if self.stp is not None: self.stp.update(tdi, pre_spike)
-
- # update synaptic variables
- self.g.value, self.h.value = self.integral(self.g, self.h, t, dt)
- self.h += pre_spike
+ syn = _DelayedDualExp.desc(pre.size,
+ pre.keep_size,
+ mode=mode,
+ tau_decay=tau_decay,
+ tau_rise=tau_rise,
+ method=method,
+ stp=stp,
+ master=self)
+
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ output=output,
+ stp=stp,
+ comp_method=comp_method,
+ g_max=g_max,
+ delay_step=delay_step,
+ name=name,
+ mode=mode)
- # post values
- syn_value = self.g.value
- if self.stp is not None: syn_value = self.stp(syn_value)
- if isinstance(self.conn, All2All):
- post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
- elif isinstance(self.conn, One2One):
- post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
- else:
- if self.comp_method == 'sparse':
- f = lambda s: bm.sparse.csrmv(
- self.g_max, self.conn_mask[0], self.conn_mask[1], s,
- shape=(self.pre.num, self.post.num),
- transpose=True
- )
- if isinstance(self.mode, bm.BatchingMode): f = vmap(f)
- post_vs = f(syn_value)
- else:
- post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+ self.check_post_attrs('input')
+ # copy the references
+ syn = self.pre.after_updates[self.proj._syn_id].syn.syn
+ self.g = syn.g
+ self.h = syn.h
- # output
- return self.output(post_vs)
+ def update(self, pre_spike=None):
+ return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
class Alpha(DualExponential):
@@ -614,9 +546,9 @@ class Alpha(DualExponential):
Parameters
----------
- pre: NeuGroup
+ pre: NeuDyn
The pre-synaptic neuron group.
- post: NeuGroup
+ post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
@@ -644,11 +576,11 @@ class Alpha(DualExponential):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = CUBA(),
- stp: Optional[SynSTP] = None,
+ output: _SynOut = None, # CUBA(),
+ stp: Optional[_SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
@@ -676,7 +608,22 @@ def __init__(
stop_spike_gradient=stop_spike_gradient)
-class NMDA(TwoEndConn):
+class _DelayedNMDA(_DelayedSyn):
+ not_desc_params = ('master', 'stp', 'mode')
+
+ def __init__(self, size, keep_size, mode, a, tau_decay, tau_rise, method, master, stp=None):
+ syn = synapses.NMDA(size,
+ keep_size,
+ mode=mode,
+ a=a,
+ tau_decay=tau_decay,
+ tau_rise=tau_rise,
+ method=method)
+ stp = _init_stp(stp, master)
+ super().__init__(syn, stp)
+
+
+class NMDA(_TwoEndConnAlignPre):
r"""NMDA synapse model.
**Model Descriptions**
@@ -763,9 +710,9 @@ class NMDA(TwoEndConn):
Parameters
----------
- pre: NeuGroup
+ pre: NeuDyn
The pre-synaptic neuron group.
- post: NeuGroup
+ post: NeuDyn
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
@@ -805,11 +752,11 @@ class NMDA(TwoEndConn):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2),
- stp: Optional[SynSTP] = None,
+ output: _SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2),
+ stp: Optional[_SynSTP] = None,
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 0.15,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
@@ -817,21 +764,11 @@ def __init__(
a: Union[float, ArrayType] = 0.5,
tau_rise: Union[float, ArrayType] = 2.,
method: str = 'exp_auto',
-
- # other parameters
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
stop_spike_gradient: bool = False,
):
- super(NMDA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=output,
- stp=stp,
- name=name,
- mode=mode)
# parameters
- # self.check_post_attrs('input', 'V')
self.tau_decay = tau_decay
self.tau_rise = tau_rise
self.a = a
@@ -844,146 +781,32 @@ def __init__(
self.comp_method = comp_method
self.stop_spike_gradient = stop_spike_gradient
- # connections and weights
- self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
-
- # variables
- self.g = variable_(bm.zeros, self.pre.num, self.mode)
- self.x = variable_(bm.zeros, self.pre.num, self.mode)
- self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
-
- # integral
- self.integral = odeint(method=method, f=JointEq(self.dg, self.dx))
-
- def dg(self, g, t, x):
- return -g / self.tau_decay + self.a * x * (1 - g)
-
- def dx(self, x, t):
- return -x / self.tau_rise
-
- def reset_state(self, batch_size=None):
- self.g.value = variable_(bm.zeros, self.pre.num, batch_size)
- self.x.value = variable_(bm.zeros, self.pre.num, batch_size)
- self.output.reset_state(batch_size)
- if self.stp is not None: self.stp.reset_state(batch_size)
-
- def update(self, tdi, pre_spike=None):
- t, dt = tdi['t'], tdi['dt']
- # delays
- if pre_spike is None:
- pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
- pre_spike = bm.as_jax(pre_spike)
- if self.stop_spike_gradient:
- pre_spike = stop_gradient(pre_spike)
-
- # update sub-components
- self.output.update(tdi)
- if self.stp is not None: self.stp.update(tdi, pre_spike)
-
- # update synapse variables
- self.g.value, self.x.value = self.integral(self.g, self.x, t, dt=dt)
- self.x += pre_spike
-
- # post-synaptic value
- syn_value = self.g.value
- if self.stp is not None: syn_value = self.stp(syn_value)
- if isinstance(self.conn, All2All):
- post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
- elif isinstance(self.conn, One2One):
- post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
- else:
- if self.comp_method == 'sparse':
- f = lambda s: bm.event.csrmv(
- self.g_max, self.conn_mask[0], self.conn_mask[1], s,
- shape=(self.pre.num, self.post.num),
- transpose=True
- )
- if isinstance(self.mode, bm.BatchingMode): f = vmap(f)
- post_vs = f(syn_value)
- else:
- post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
-
- # output
- return self.output(post_vs)
-
-
-class PoissonInput(SynConn):
- """Poisson Input to the given `Variable`.
-
- Adds independent Poisson input to a target variable. For large
- numbers of inputs, this is much more efficient than creating a
- `PoissonGroup`. The synaptic events are generated randomly during the
- simulation and are not preloaded and stored in memory. All the inputs must
- target the same variable, have the same frequency and same synaptic weight.
- All neurons in the target variable receive independent realizations of
- Poisson spike trains.
-
- Parameters
- ----------
- target_var: Variable
- The variable that is targeted by this input.
- num_input: int
- The number of inputs.
- freq: float
- The frequency of each of the inputs. Must be a scalar.
- weight: float
- The synaptic weight. Must be a scalar.
- """
-
- def __init__(
- self,
- target_var: bm.Variable,
- num_input: int,
- freq: Union[int, float],
- weight: Union[int, float],
- seed: Optional[int] = None,
- mode: bm.Mode = None,
- name: str = None
- ):
- from ..neurons.input_groups import InputGroup, OutputGroup
- super(PoissonInput, self).__init__(InputGroup(1), OutputGroup(1), name=name, mode=mode)
- self.pre = None
- self.post = None
-
- # check data
- if not isinstance(target_var, bm.Variable):
- raise TypeError(f'"target_var" must be an instance of Variable. '
- f'But we got {type(target_var)}: {target_var}')
- is_integer(num_input, 'num_input', min_bound=1)
- is_float(freq, 'freq', min_bound=0., allow_int=True)
- is_float(weight, 'weight', allow_int=True)
- is_subclass(mode, (bm.NonBatchingMode, bm.BatchingMode), name=self.__class__.__name__)
-
- # parameters
- self.target_var = target_var
- self.num_input = num_input
- self.freq = freq
- self.weight = weight
- self.seed = seed
-
- def update(self, tdi):
- p = self.freq * tdi.dt / 1e3
- a = self.num_input * p
- b = self.num_input * (1 - p)
- if isinstance(tdi.dt, (int, float)): # dt is not in tracing
- if (a > 5) and (b > 5):
- inp = bm.random.normal(a, b * p, self.target_var.shape)
- else:
- inp = bm.random.binomial(self.num_input, p, self.target_var.shape)
-
- else: # dt is in tracing
- inp = bm.cond((a > 5) * (b > 5),
- lambda _: bm.random.normal(a, b * p, self.target_var.shape),
- lambda _: bm.random.binomial(self.num_input, p, self.target_var.shape),
- None)
- self.target_var += inp * self.weight
-
- def __repr__(self):
- names = self.__class__.__name__
- return f'{names}(name={self.name}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})'
-
- def reset_state(self, batch_size=None):
- pass
-
- def reset(self, batch_size=None):
- self.reset_state(batch_size)
+ syn = _DelayedNMDA.desc(pre.size,
+ pre.keep_size,
+ mode=mode,
+ a=a,
+ tau_decay=tau_decay,
+ tau_rise=tau_rise,
+ method=method,
+ stp=stp,
+ master=self)
+
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ output=output,
+ stp=stp,
+ comp_method=comp_method,
+ g_max=g_max,
+ delay_step=delay_step,
+ name=name,
+ mode=mode)
+
+ # copy the references
+ syn = self.pre.after_updates[self.proj._syn_id].syn.syn
+ self.g = syn.g
+ self.x = syn.x
+
+ def update(self, pre_spike=None):
+ return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
new file mode 100644
index 000000000..b36b40c9b
--- /dev/null
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -0,0 +1,562 @@
+from typing import Union, Dict, Callable, Optional, Tuple
+
+import jax
+import numpy as np
+
+from brainpy import math as bm
+from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
+from brainpy._src.dnn import linear
+from brainpy._src.dyn import projections
+from brainpy._src.dynsys import Projection, DynamicalSystem, NeuDyn, Sequential
+from brainpy._src.initialize import parameter
+from brainpy._src.mixin import (ParamDesc, ParamDescInit, JointType,
+ AutoDelaySupp, BindCondData, AlignPost,
+ ReturnInfo)
+from brainpy.errors import UnsupportedError
+from brainpy.types import ArrayType
+
+__all__ = [
+ 'SynConn',
+ '_SynSTP',
+ '_SynOut',
+ 'TwoEndConn',
+ '_TwoEndConnAlignPre',
+ '_TwoEndConnAlignPost',
+]
+
+
+class SynConn(Projection):
+ """Base class to model two-end synaptic connections.
+
+ Parameters
+ ----------
+ pre : NeuGroup
+ Pre-synaptic neuron group.
+ post : NeuGroup
+ Post-synaptic neuron group.
+ conn : optional, ndarray, ArrayType, dict, TwoEndConnector
+ The connection method between pre- and post-synaptic groups.
+ name : str, optional
+ The name of the dynamic system.
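+
+ The ``conn`` argument accepts several forms (a brief sketch; the neuron groups
+ below are illustrative assumptions, not part of this class):
+
+ >>> import numpy as np
+ >>> import brainpy as bp
+ >>> pre, post = bp.neurons.LIF(4), bp.neurons.LIF(3)
+ >>> conn1 = bp.conn.All2All()                               # a TwoEndConnector instance
+ >>> conn2 = np.ones((4, 3), dtype=bool)                     # a dense (pre.num, post.num) matrix
+ >>> conn3 = {'i': np.array([0, 1]), 'j': np.array([2, 0])}  # pre/post index pairs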
+ """
+
+ def __init__(
+ self,
+ pre: DynamicalSystem,
+ post: DynamicalSystem,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # pre or post neuron group
+ # ------------------------
+ if not isinstance(pre, DynamicalSystem):
+ raise TypeError('"pre" must be an instance of DynamicalSystem.')
+ if not isinstance(post, DynamicalSystem):
+ raise TypeError('"post" must be an instance of DynamicalSystem.')
+ self.pre = pre
+ self.post = post
+
+ # connectivity
+ # ------------
+ if isinstance(conn, TwoEndConnector):
+ self.conn = conn(pre.size, post.size)
+ elif isinstance(conn, (bm.Array, np.ndarray, jax.Array)):
+ if (pre.num, post.num) != conn.shape:
+ raise ValueError(f'"conn" is provided as a matrix, and it is expected '
+ f'to be an array with shape of (pre.num, post.num) = '
+ f'{(pre.num, post.num)}, however we got {conn.shape}')
+ self.conn = MatConn(conn_mat=conn)
+ elif isinstance(conn, dict):
+ if not ('i' in conn and 'j' in conn):
+ raise ValueError(f'"conn" is provided as a dict, and it is expected to '
+ f'be a dictionary with "i" and "j" specification, '
+ f'however we got {conn}')
+ self.conn = IJConn(i=conn['i'], j=conn['j'])
+ elif isinstance(conn, str):
+ self.conn = conn
+ elif conn is None:
+ self.conn = None
+ else:
+ raise ValueError(f'Unknown "conn" type: {conn}')
+
+ def __repr__(self):
+ names = self.__class__.__name__
+ return (f'{names}(name={self.name}, mode={self.mode}, \n'
+ f'{" " * len(names)} pre={self.pre}, \n'
+ f'{" " * len(names)} post={self.post})')
+
+ def check_pre_attrs(self, *attrs):
+ """Check whether pre group satisfies the requirement."""
+ if not hasattr(self, 'pre'):
+ raise ValueError('Please call __init__ function first.')
+ for attr in attrs:
+ if not isinstance(attr, str):
+ raise TypeError(f'Must be string. But got {attr}.')
+ if not hasattr(self.pre, attr):
+ raise ValueError(f'{self} requires the "pre" neuron group to have attribute "{attr}".')
+
+ def check_post_attrs(self, *attrs):
+ """Check whether post group satisfies the requirement."""
+ if not hasattr(self, 'post'):
+ raise ValueError('Please call __init__ function first.')
+ for attr in attrs:
+ if not isinstance(attr, str):
+ raise TypeError(f'Must be string. But got {attr}.')
+ if not hasattr(self.post, attr):
+ raise ValueError(f'{self} requires the "post" neuron group to have attribute "{attr}".')
+
+ def update(self, *args, **kwargs):
+ """The function to specify the updating rule.
+
+ Assume any dynamical system depends on the shared arguments (``sha``),
+ such as the time ``t``, the time step size ``dt``, and the step index ``i``.
+ """
+ raise NotImplementedError('The "update" function must be implemented by the subclass.')
+
+
+class _SynapseComponent(DynamicalSystem):
+ """Base class for modeling synaptic components,
+ including synaptic output, synaptic short-term plasticity,
+ synaptic long-term plasticity, and others. """
+
+ '''Master of this component.'''
+ master: SynConn
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self._registered = False
+
+ @property
+ def isregistered(self) -> bool:
+ """State of the component, representing whether it has been registered."""
+ return self._registered
+
+ @isregistered.setter
+ def isregistered(self, val: bool):
+ if not isinstance(val, bool):
+ raise ValueError('Must be an instance of bool.')
+ self._registered = val
+
+ def reset_state(self, batch_size=None):
+ pass
+
+ def register_master(self, master: SynConn):
+ if not isinstance(master, SynConn):
+ raise TypeError(f'master must be an instance of {SynConn.__name__}, but we got {type(master)}')
+ if self.isregistered:
+ raise ValueError('This component has already been registered to a master, but another registration was requested.')
+ if hasattr(self, 'master') and self.master != master:
+ raise ValueError('This component has already been registered to a different master, but another registration was requested.')
+ self.master = master
+ self._registered = True
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+ def __call__(self, *args, **kwargs):
+ return self.filter(*args, **kwargs)
+
+ def clone(self) -> '_SynapseComponent':
+ """The function useful to clone a new object when it has been used."""
+ raise NotImplementedError
+
+ def filter(self, g):
+ raise NotImplementedError
+
+
+class _SynOut(_SynapseComponent, ParamDesc):
+ """Base class for synaptic current output."""
+
+ def __init__(
+ self,
+ name: str = None,
+ target_var: Union[str, bm.Variable] = None,
+ ):
+ super().__init__(name=name)
+ # check target variable
+ if target_var is not None:
+ if not isinstance(target_var, (str, bm.Variable)):
+ raise TypeError('"target_var" must be instance of string or Variable. '
+ f'But we got {type(target_var)}')
+ self.target_var: Optional[bm.Variable] = target_var
+
+ def register_master(self, master: SynConn):
+ super().register_master(master)
+
+ # initialize target variable to output
+ if isinstance(self.target_var, str):
+ if not hasattr(self.master.post, self.target_var):
+ raise KeyError(f'Post-synaptic group does not have target variable: {self.target_var}')
+ self.target_var = getattr(self.master.post, self.target_var)
+
+ def filter(self, g):
+ if self.target_var is None:
+ return g
+ else:
+ self.target_var += g
+
+ def update(self):
+ pass
+
+
+class _SynSTP(_SynapseComponent, ParamDesc, AutoDelaySupp):
+ """Base class for synaptic short-term plasticity."""
+
+ def update(self, pre_spike):
+ pass
+
+ def return_info(self):
+ assert self.isregistered
+ return ReturnInfo(self.master.pre.varshape, None, self.master.pre.mode, init=bm.zeros)
+
+
+class _NullSynOut(_SynOut):
+ def clone(self):
+ return _NullSynOut()
+
+
+class TwoEndConn(SynConn):
+ """Base class to model synaptic connections.
+
+ Parameters
+ ----------
+ pre : NeuGroup
+ Pre-synaptic neuron group.
+ post : NeuGroup
+ Post-synaptic neuron group.
+ conn : optional, ndarray, ArrayType, dict, TwoEndConnector
+ The connection method between pre- and post-synaptic groups.
+ output: Optional, SynOut
+ The output for the synaptic current.
+
+ .. versionadded:: 2.1.13
+ The output component for a two-end connection model.
+
+ stp: Optional, SynSTP
+ The short-term plasticity model for the synaptic variables.
+
+ .. versionadded:: 2.1.13
+ The short-term plasticity component for a two-end connection model.
+
+ ltp: Optional, SynLTP
+ The long-term plasticity model for the synaptic variables.
+
+ .. versionadded:: 2.1.13
+ The long-term plasticity component for a two-end connection model.
+
+ name: Optional, str
+ The name of the dynamic system.
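+
+ A minimal, illustrative sketch (the concrete ``Exponential``, ``COBA`` and ``STD``
+ names are taken from this package's public API and are used here only as an
+ example) of how the output and short-term plasticity components attach to a
+ two-end connection:
+
+ >>> import brainpy as bp
+ >>> pre = bp.neurons.LIF(5)
+ >>> post = bp.neurons.LIF(5)
+ >>> syn = bp.synapses.Exponential(pre, post, bp.conn.All2All(),
+ ...                               output=bp.synouts.COBA(E=0.),
+ ...                               stp=bp.synplast.STD())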
+ """
+
+ def __init__(
+ self,
+ pre: DynamicalSystem,
+ post: DynamicalSystem,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None,
+ output: _SynOut = _NullSynOut(),
+ stp: Optional[_SynSTP] = None,
+ ltp: Optional = None,
+ mode: bm.Mode = None,
+ name: str = None,
+ init_stp: bool = True
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ name=name,
+ mode=mode)
+
+ # synaptic output
+ output = _NullSynOut() if output is None else output
+ if output.isregistered:
+ output = output.clone()
+ if not isinstance(output, _SynOut):
+ raise TypeError(f'output must be instance of {_SynOut.__name__}, '
+ f'but we got {type(output)}')
+ output.register_master(master=self)
+ self.output: _SynOut = output
+
+ # short-term synaptic plasticity
+ if init_stp:
+ stp = _init_stp(stp, self)
+ self.stp: Optional[_SynSTP] = stp
+
+ def _init_weights(
+ self,
+ weight: Union[float, ArrayType, Callable],
+ comp_method: str,
+ sparse_data: str = 'csr'
+ ) -> Tuple[Union[float, ArrayType], ArrayType]:
+ if comp_method not in ['sparse', 'dense']:
+ raise ValueError(f'"comp_method" must be either "sparse" or "dense", but we got {comp_method}')
+ if sparse_data not in ['csr', 'ij', 'coo']:
+ raise ValueError(f'"sparse_data" must be one of "csr", "ij" or "coo", but we got {sparse_data}')
+ if self.conn is None:
+ raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
+
+ # connections and weights
+ if isinstance(self.conn, One2One):
+ weight = parameter(weight, (self.pre.num,), allow_none=False)
+ conn_mask = None
+
+ elif isinstance(self.conn, All2All):
+ weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
+ conn_mask = None
+
+ else:
+ if comp_method == 'sparse':
+ if sparse_data == 'csr':
+ conn_mask = self.conn.require('pre2post')
+ elif sparse_data in ['ij', 'coo']:
+ conn_mask = self.conn.require('post_ids', 'pre_ids')
+ else:
+ raise ValueError(f'Unknown sparse data type: {sparse_data}')
+ weight = parameter(weight, conn_mask[0].shape, allow_none=False)
+ elif comp_method == 'dense':
+ weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
+ conn_mask = self.conn.require('conn_mat')
+ else:
+ raise ValueError(f'Unknown connection type: {comp_method}')
+
+ # training weights
+ if isinstance(self.mode, bm.TrainingMode):
+ weight = bm.TrainVar(weight)
+ return weight, conn_mask
+
+ def _syn2post_with_all2all(self, syn_value, syn_weight):
+ if bm.ndim(syn_weight) == 0:
+ if isinstance(self.mode, bm.BatchingMode):
+ post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
+ else:
+ post_vs = bm.sum(syn_value)
+ if not self.conn.include_self:
+ post_vs = post_vs - syn_value
+ post_vs = syn_weight * post_vs
+ else:
+ post_vs = syn_value @ syn_weight
+ return post_vs
+
+ def _syn2post_with_one2one(self, syn_value, syn_weight):
+ return syn_value * syn_weight
+
+ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
+ if bm.ndim(syn_weight) == 0:
+ post_vs = (syn_weight * syn_value) @ conn_mat
+ else:
+ post_vs = syn_value @ (syn_weight * conn_mat)
+ return post_vs
+
+
+def _init_stp(stp, master):
+ if stp is not None:
+ if stp.isregistered:
+ stp = stp.clone()
+ if not isinstance(stp, _SynSTP):
+ raise TypeError(f'Short-term plasticity must be instance of {_SynSTP.__name__}, '
+ f'but we got {type(stp)}')
+ stp.register_master(master=master)
+ return stp
+
+
+def _get_delay(delay_step):
+ if delay_step is None:
+ return None
+ elif callable(delay_step):
+ raise UnsupportedError('Currently, "delay_step" only supports an integer; a callable is not supported.')
+ else:
+ return delay_step * bm.get_dt()
+
+
+class _TempOut(DynamicalSystem, BindCondData, ParamDesc):
+ def update(self, *args, **kwargs):
+ pass
+
+
+class _TwoEndConnAlignPre(TwoEndConn):
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]],
+ conn: TwoEndConnector,
+ g_max: Union[float, ArrayType, Callable],
+ output: JointType[DynamicalSystem, BindCondData] = _NullSynOut(),
+ stp: Optional[_SynSTP] = None,
+ comp_method: str = 'dense',
+ delay_step: Union[int, ArrayType, Callable] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ assert isinstance(pre, NeuDyn)
+ assert isinstance(post, NeuDyn)
+ assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
+
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ output=output,
+ stp=None,
+ name=name,
+ mode=mode,
+ init_stp=False)
+
+ delay = _get_delay(delay_step)
+
+ # Projection
+ if isinstance(conn, All2All):
+ proj = projections.ProjAlignPre(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.AllToAll(pre.num, post.num, g_max),
+ out=_TempOut(),
+ post=post)
+
+ elif isinstance(conn, One2One):
+ assert post.num == pre.num
+ proj = projections.ProjAlignPre(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.OneToOne(pre.num, g_max),
+ out=_TempOut(),
+ post=post)
+
+ else:
+ if comp_method == 'dense':
+ proj = projections.ProjAlignPre(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.MaskedLinear(conn, g_max),
+ out=_TempOut(),
+ post=post)
+
+ elif comp_method == 'sparse':
+ proj = projections.ProjAlignPre(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.CSRLinear(conn, g_max),
+ out=_TempOut(),
+ post=post)
+
+ else:
+ raise UnsupportedError(f'comp_method "{comp_method}" is not supported; only "sparse" and "dense" are supported.')
+ self.proj = proj
+ self.proj.post.cur_inputs.pop(self.proj.name)
+ self.stp = self.pre.after_updates[self.proj._syn_id].syn.stp
+
+ def update(self, pre_spike=None, stop_spike_gradient: bool = False):
+ if pre_spike is None:
+ pre_spike = self.pre.after_updates[self.proj._syn_id].delay.at(self.proj.name)
+ if stop_spike_gradient:
+ pre_spike = jax.lax.stop_gradient(pre_spike)
+ current = self.proj.comm(pre_spike)
+ return self.output(current)
+
+
+class _TwoEndConnAlignPost(TwoEndConn):
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
+ conn: TwoEndConnector,
+ g_max: Union[float, ArrayType, Callable],
+ output: _SynOut = _NullSynOut(),
+ stp: Optional[_SynSTP] = None,
+ comp_method: str = 'dense',
+ delay_step: Union[int, ArrayType, Callable] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ output=output,
+ stp=stp,
+ name=name,
+ mode=mode,
+ init_stp=True)
+
+ pre = _DelayedSyn(pre, self.stp)
+ delay = _get_delay(delay_step)
+
+ # make every synapse unique
+ syn._identifier = syn._identifier + f' // {self.name}'
+
+ # Projection
+ if isinstance(conn, All2All):
+ proj = projections.ProjAlignPost(pre=pre,
+ delay=delay,
+ comm=linear.AllToAll(self.pre.num, self.post.num, g_max),
+ syn=syn,
+ out=_TempOut.desc(),
+ post=post)
+
+ elif isinstance(conn, One2One):
+ assert post.num == self.pre.num
+ proj = projections.ProjAlignPost(pre=pre,
+ delay=delay,
+ comm=linear.OneToOne(self.pre.num, g_max),
+ syn=syn,
+ out=_TempOut.desc(),
+ post=post)
+
+ else:
+ if comp_method == 'dense':
+ proj = projections.ProjAlignPost(pre=pre,
+ delay=delay,
+ comm=linear.MaskedLinear(conn, g_max),
+ syn=syn,
+ out=_TempOut.desc(),
+ post=post)
+
+ elif comp_method == 'sparse':
+ if self.stp is None:
+ comm = linear.EventCSRLinear(conn, g_max)
+ else:
+ comm = linear.CSRLinear(conn, g_max)
+ proj = projections.ProjAlignPost(pre=pre,
+ delay=delay,
+ comm=comm,
+ syn=syn,
+ out=_TempOut.desc(),
+ post=post)
+
+ else:
+ raise UnsupportedError(f'comp_method "{comp_method}" is not supported; only "sparse" and "dense" are supported.')
+ self.proj = proj
+ self.proj.post.cur_inputs.pop(self.proj.name)
+
+ def update(self, pre_spike=None, stop_spike_gradient: bool = False):
+ if pre_spike is None:
+ pre_spike = self.proj.pre.after_updates[self.proj._delay_repr].at(self.proj.name)
+ if stop_spike_gradient:
+ # TODO: if self.stp is not None
+ pre_spike = jax.lax.stop_gradient(pre_spike)
+ current = self.proj.comm(pre_spike)
+ self.proj.post.before_updates[self.proj._post_repr].syn.add_current(current) # synapse post current
+ return self.output(current)
+
+
+class _DelayedSyn(DynamicalSystem, ParamDesc, AutoDelaySupp):
+ def __init__(self, syn, stp=None):
+ super().__init__()
+ self.syn = syn
+ self.stp = stp
+
+ def update(self, x):
+ if self.stp is None:
+ return self.syn(x)
+ else:
+ self.stp.update(x)
+ return self.stp(self.syn(x))
+
+ def return_info(self):
+ if self.stp is None:
+ return self.syn.return_info()
+ else:
+ return self.stp.return_info()
+
diff --git a/brainpy/_src/dynold/synapses/biological_models.py b/brainpy/_src/dynold/synapses/biological_models.py
new file mode 100644
index 000000000..861db52e9
--- /dev/null
+++ b/brainpy/_src/dynold/synapses/biological_models.py
@@ -0,0 +1,414 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union, Dict, Callable, Optional
+
+import brainpy.math as bm
+from brainpy._src.connect import TwoEndConnector
+from brainpy._src.dyn import synapses
+from brainpy._src.dynold.synapses import _SynSTP, _SynOut, _TwoEndConnAlignPre
+from brainpy._src.dynold.synapses.base import _init_stp, _DelayedSyn
+from brainpy._src.dynold.synouts import COBA, MgBlock
+from brainpy._src.dynsys import NeuDyn
+from brainpy.types import ArrayType
+
+__all__ = [
+ 'AMPA',
+ 'GABAa',
+ 'BioNMDA',
+]
+
+
+class _DelayedAMPA(_DelayedSyn):
+ not_desc_params = ('master', 'stp', 'mode')
+
+ def __init__(self, size, keep_size, mode, alpha, beta, T, T_dur, method, master, stp=None):
+ syn = synapses.AMPA(size,
+ keep_size,
+ mode=mode,
+ alpha=alpha,
+ beta=beta,
+ T=T,
+ T_dur=T_dur,
+ method=method)
+ stp = _init_stp(stp, master)
+ super().__init__(syn, stp)
+
+
+class AMPA(_TwoEndConnAlignPre):
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ output: _SynOut = COBA(E=0.),
+ stp: Optional[_SynSTP] = None,
+ comp_method: str = 'dense',
+ g_max: Union[float, ArrayType, Callable] = 0.42,
+ delay_step: Union[int, ArrayType, Callable] = None,
+ alpha: float = 0.98,
+ beta: float = 0.18,
+ T: float = 0.5,
+ T_duration: float = 0.5,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ stop_spike_gradient: bool = False,
+ ):
+ # parameters
+ self.stop_spike_gradient = stop_spike_gradient
+ self.comp_method = comp_method
+ self.alpha = alpha
+ self.beta = beta
+ self.T = T
+ self.T_duration = T_duration
+ if bm.size(alpha) != 1:
+ raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. But we got {alpha}')
+ if bm.size(beta) != 1:
+ raise ValueError(f'"beta" must be a scalar or a tensor with size of 1. But we got {beta}')
+ if bm.size(T) != 1:
+ raise ValueError(f'"T" must be a scalar or a tensor with size of 1. But we got {T}')
+ if bm.size(T_duration) != 1:
+ raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')
+
+ # AMPA
+ syn = _DelayedAMPA.desc(
+ pre.size, pre.keep_size, mode=mode, alpha=alpha, beta=beta,
+ T=T, T_dur=T_duration, method=method, stp=stp, master=self,
+ )
+
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ output=output,
+ stp=stp,
+ comp_method=comp_method,
+ g_max=g_max,
+ delay_step=delay_step,
+ name=name,
+ mode=mode)
+
+ # copy the references
+ syn = self.pre.after_updates[self.proj._syn_id].syn.syn
+ self.g = syn.g
+ self.spike_arrival_time = syn.spike_arrival_time
+
+ def update(self, pre_spike=None):
+ return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
+
+
+class GABAa(AMPA):
+ r"""GABAa synapse model.
+
+ **Model Descriptions**
+
+ The GABAa synapse model has the same equations as the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_,
+
+ .. math::
+
+ \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
+ I_{syn}&= - g_{max} g (V - E)
+
+ but with the difference of:
+
+ - Reversal potential of synapse :math:`E` is usually low, typically -80. mV
+ - Activating rate constant :math:`\alpha=0.53`
+ - De-activating rate constant :math:`\beta=0.18`
+ - Transmitter concentration :math:`[T]=1\,\mathrm{mM}` when the synapse is
+ triggered by a pre-synaptic spike, with a duration of 1 ms.
+
+ **Model Examples**
+
+ - `Gamma oscillation network model `_
+
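+ A minimal usage sketch (mirroring the ``BioNMDA`` example later in this file;
+ the ``HH`` neuron, input strength and run length are illustrative assumptions):
+
+ >>> import brainpy as bp
+ >>> from brainpy import neurons, synapses
+ >>> neu1 = neurons.HH(1)
+ >>> neu2 = neurons.HH(1)
+ >>> syn1 = synapses.GABAa(neu1, neu2, bp.connect.All2All())
+ >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
+ >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g'])
+ >>> runner.run(150.)
+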
+
+ Parameters
+ ----------
+ pre: NeuDyn
+ The pre-synaptic neuron group.
+ post: NeuDyn
+ The post-synaptic neuron group.
+ conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
+ The synaptic connections.
+ comp_method: str
+ The connection type used for model speed optimization. It can be
+ `sparse` and `dense`. The default is `dense`.
+ delay_step: int, ArrayType, Callable
+ The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
+ g_max: float, ArrayType, Callable
+ The synaptic strength (the maximum conductance). Default is 1.
+ alpha: float, ArrayType
+ Binding (activating) rate constant. Default 0.53
+ beta: float, ArrayType
+ Unbinding (de-activating) rate constant. Default 0.18
+ T: float, ArrayType
+ Transmitter concentration when the synapse is triggered by
+ a pre-synaptic spike. Default 1 [mM].
+ T_duration: float, ArrayType
+ Duration of the transmitter concentration after the synapse is triggered. Default 1 [ms].
+ name: str
+ The name of this synaptic projection.
+ method: str
+ The numerical integration methods.
+
+ References
+ ----------
+ .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
+ on the integrative properties of neocortical pyramidal neurons
+ in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ output: _SynOut = COBA(E=-80.),
+ stp: Optional[_SynSTP] = None,
+ comp_method: str = 'dense',
+ g_max: Union[float, ArrayType, Callable] = 0.04,
+ delay_step: Union[int, ArrayType, Callable] = None,
+ alpha: Union[float, ArrayType] = 0.53,
+ beta: Union[float, ArrayType] = 0.18,
+ T: Union[float, ArrayType] = 1.,
+ T_duration: Union[float, ArrayType] = 1.,
+ method: str = 'exp_auto',
+
+ # other parameters
+ name: str = None,
+ mode: bm.Mode = None,
+ stop_spike_gradient: bool = False,
+ ):
+ super(GABAa, self).__init__(pre=pre,
+ post=post,
+ conn=conn,
+ output=output,
+ stp=stp,
+ comp_method=comp_method,
+ delay_step=delay_step,
+ g_max=g_max,
+ alpha=alpha,
+ beta=beta,
+ T=T,
+ T_duration=T_duration,
+ method=method,
+ name=name,
+ mode=mode,
+ stop_spike_gradient=stop_spike_gradient, )
+
+
+class _DelayedNMDA(_DelayedSyn):
+ not_desc_params = ('master', 'stp', 'mode')
+
+ def __init__(self, size, keep_size, alpha1, beta1, alpha2, beta2, T, T_dur, method, mode, master, stp=None):
+ syn = synapses.BioNMDA(size,
+ keep_size,
+ mode=mode,
+ alpha1=alpha1,
+ beta1=beta1,
+ alpha2=alpha2,
+ beta2=beta2,
+ T=T,
+ T_dur=T_dur,
+ method=method)
+ stp = _init_stp(stp, master)
+ super().__init__(syn, stp)
+
+
+class BioNMDA(_TwoEndConnAlignPre):
+ r"""Biological NMDA synapse model.
+
+ **Model Descriptions**
+
+ The NMDA receptor is a glutamate receptor and ion channel found in neurons.
+ The NMDA receptor is one of three types of ionotropic glutamate receptors,
+ the other two being AMPA and kainate receptors.
+
+ The NMDA receptor mediated conductance depends on the postsynaptic voltage.
+ The voltage dependence is due to the blocking of the pore of the NMDA receptor
+ from the outside by a positively charged magnesium ion. The channel is
+ nearly completely blocked at resting potential, but the magnesium block is
+ relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
+ that are not blocked by magnesium can be fitted to
+
+ .. math::
+
+ g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V}
+ \frac{[{Mg}^{2+}]_{o}} {b})^{-1}
+
+ Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
+ usually 1 mM. Thus, the channel acts as a
+ "coincidence detector": only when both of these conditions are met (pre-synaptic
+ glutamate release and post-synaptic depolarization) does the channel open,
+ allowing positively charged ions (cations) to flow through
+ the cell membrane [2]_.
+
+ If we make the approximation that the magnesium block changes
+ instantaneously with voltage and is independent of the gating of the channel,
+ the net NMDA receptor-mediated synaptic current is given by
+
+ .. math::
+
+ I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
+
+ where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
+ reversal potential.
+
+ Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_:
+
+ .. math::
+
+ & g_\mathrm{NMDA} (t) = g_{max} g \\
+ & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\
+ & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x
+
+ where :math:`\alpha_1, \beta_1` refer to the conversion rates of variable g and
+ :math:`\alpha_2, \beta_2` refer to the conversion rates of variable x.
+
+ The NMDA receptor has been thought to be very important for controlling
+ synaptic plasticity and mediating learning and memory functions [3]_.
+
+ .. plot::
+ :include-source: True
+
+ >>> import brainpy as bp
+ >>> from brainpy import neurons, synapses
+ >>> import matplotlib.pyplot as plt
+ >>>
+ >>> neu1 = neurons.HH(1)
+ >>> neu2 = neurons.HH(1)
+ >>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All())
+ >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
+ >>>
+ >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
+ >>> runner.run(150.)
+ >>>
+ >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
+ >>> fig.add_subplot(gs[0, 0])
+ >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
+ >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
+ >>> plt.legend()
+ >>>
+ >>> fig.add_subplot(gs[1, 0])
+ >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
+ >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x')
+ >>> plt.legend()
+ >>> plt.show()
+
+ Parameters
+ ----------
+ pre: NeuDyn
+ The pre-synaptic neuron group.
+ post: NeuDyn
+ The post-synaptic neuron group.
+ conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
+ The synaptic connections.
+ comp_method: str
+ The connection type used for model speed optimization. It can be
+ `sparse` and `dense`. The default is `dense`.
+ delay_step: int, ArrayType, Callable
+ The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
+ g_max: float, ArrayType, Callable
+ The synaptic strength (the maximum conductance). Default is 1.
+ alpha1: float, ArrayType
+ The conversion rate of g from inactive to active. Default 2 ms^-1.
+ beta1: float, ArrayType
+ The conversion rate of g from active to inactive. Default 0.01 ms^-1.
+ alpha2: float, ArrayType
+ The conversion rate of x from inactive to active. Default 1 ms^-1.
+ beta2: float, ArrayType
+ The conversion rate of x from active to inactive. Default 0.5 ms^-1.
+ name: str
+ The name of this synaptic projection.
+ method: str
+ The numerical integration methods.
+
+ References
+ ----------
+
+ .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M].
+ Springer New York, 2010: 162.
+ .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
+ Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
+ Nature 438, no. 7065 (2005): 185-192.
+ .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
+ England journal of medicine, 361(3), p.302.
+ .. [4] https://en.wikipedia.org/wiki/NMDA_receptor
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ output: _SynOut = MgBlock(E=0.),
+ stp: Optional[_SynSTP] = None,
+ comp_method: str = 'dense',
+ g_max: Union[float, ArrayType, Callable] = 0.15,
+ delay_step: Union[int, ArrayType, Callable] = None,
+ alpha1: Union[float, ArrayType] = 2.,
+ beta1: Union[float, ArrayType] = 0.01,
+ alpha2: Union[float, ArrayType] = 1.,
+ beta2: Union[float, ArrayType] = 0.5,
+ T_0: Union[float, ArrayType] = 1.,
+ T_dur: Union[float, ArrayType] = 0.5,
+ method: str = 'exp_auto',
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ stop_spike_gradient: bool = False,
+ ):
+
+ # parameters
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.alpha1 = alpha1
+ self.alpha2 = alpha2
+ self.T_0 = T_0
+ self.T_dur = T_dur
+ if bm.size(alpha1) != 1:
+ raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}')
+ if bm.size(beta1) != 1:
+ raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}')
+ if bm.size(alpha2) != 1:
+ raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}')
+ if bm.size(beta2) != 1:
+ raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}')
+ if bm.size(T_0) != 1:
+ raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}')
+ if bm.size(T_dur) != 1:
+ raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. But we got {T_dur}')
+ self.comp_method = comp_method
+ self.stop_spike_gradient = stop_spike_gradient
+
+ syn = _DelayedNMDA.desc(pre.size,
+ pre.keep_size,
+ mode=mode,
+ alpha1=alpha1,
+ beta1=beta1,
+ alpha2=alpha2,
+ beta2=beta2,
+ T=T_0,
+ T_dur=T_dur,
+ method=method,
+ stp=stp,
+ master=self)
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ output=output,
+ stp=stp,
+ comp_method=comp_method,
+ g_max=g_max,
+ delay_step=delay_step,
+ name=name,
+ mode=mode)
+
+ # copy the references
+ syn = self.pre.after_updates[self.proj._syn_id].syn.syn
+ self.g = syn.g
+ self.x = syn.x
+ self.spike_arrival_time = syn.spike_arrival_time
+
+ def update(self, pre_spike=None):
+ return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
diff --git a/brainpy/_src/dynold/synapses/compat.py b/brainpy/_src/dynold/synapses/compat.py
new file mode 100644
index 000000000..e4b9483bb
--- /dev/null
+++ b/brainpy/_src/dynold/synapses/compat.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+
+import warnings
+from typing import Union, Dict, Callable
+
+from brainpy._src.connect import TwoEndConnector
+from brainpy._src.dynold.synouts import COBA, CUBA
+from brainpy._src.dynsys import NeuDyn
+from brainpy._src.initialize import Initializer
+from brainpy.types import ArrayType
+from .abstract_models import Delta, Exponential, DualExponential
+
+__all__ = [
+ 'DeltaSynapse',
+ 'ExpCUBA',
+ 'ExpCOBA',
+ 'DualExpCUBA',
+ 'DualExpCOBA',
+ 'AlphaCUBA',
+ 'AlphaCOBA',
+]
+
+
+class DeltaSynapse(Delta):
+ """Delta synapse.
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.Delta" instead.
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'sparse',
+ weights: Union[float, ArrayType, Initializer, Callable] = 1.,
+ delay_step: Union[float, ArrayType, Initializer, Callable] = None,
+ post_input_key: str = 'V',
+ post_has_ref: bool = False,
+ name: str = None,
+ ):
+ warnings.warn('Please use "brainpy.synapses.Delta" instead.', DeprecationWarning)
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ output=CUBA(post_input_key),
+ name=name,
+ comp_method=conn_type,
+ g_max=weights,
+ delay_step=delay_step,
+ post_ref_key='refractory' if post_has_ref else None)
+
+
+class ExpCUBA(Exponential):
+ r"""Current-based exponential decay synapse model.
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.Exponential" instead.
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'sparse',
+ g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
+ delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ tau: Union[float, ArrayType] = 8.0,
+ name: str = None,
+ method: str = 'exp_auto',
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ name=name,
+ comp_method=conn_type,
+ g_max=g_max,
+ delay_step=delay_step,
+ tau=tau,
+ method=method,
+ output=CUBA())
+
+
+class ExpCOBA(Exponential):
+ """Conductance-based exponential decay synapse model.
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.Exponential" instead.
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ # connection
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'sparse',
+ # connection strength
+ g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
+ # synapse parameter
+ tau: Union[float, ArrayType] = 8.0,
+ E: Union[float, ArrayType] = 0.,
+ # synapse delay
+ delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ # others
+ method: str = 'exp_auto',
+ name: str = None
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ comp_method=conn_type,
+ g_max=g_max,
+ delay_step=delay_step,
+ tau=tau,
+ method=method,
+ name=name,
+ output=COBA(E=E))
+
+
+class DualExpCUBA(DualExponential):
+ r"""Current-based dual exponential synapse model.
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.DualExponential" instead.
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'dense',
+ g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
+ tau_decay: Union[float, ArrayType] = 10.0,
+ tau_rise: Union[float, ArrayType] = 1.,
+ delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ method: str = 'exp_auto',
+ name: str = None
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ comp_method=conn_type,
+ g_max=g_max,
+ tau_decay=tau_decay,
+ tau_rise=tau_rise,
+ delay_step=delay_step,
+ method=method,
+ name=name,
+ output=CUBA())
+
+
+class DualExpCOBA(DualExponential):
+ """Conductance-based dual exponential synapse model.
+
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.DualExponential" instead.
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'dense',
+ g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
+ delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ tau_decay: Union[float, ArrayType] = 10.0,
+ tau_rise: Union[float, ArrayType] = 1.,
+ E: Union[float, ArrayType] = 0.,
+ method: str = 'exp_auto',
+ name: str = None
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ comp_method=conn_type,
+ g_max=g_max,
+ tau_decay=tau_decay,
+ tau_rise=tau_rise,
+ delay_step=delay_step,
+ method=method,
+ name=name,
+ output=COBA(E=E))
+
+
+class AlphaCUBA(DualExpCUBA):
+ r"""Current-based alpha synapse model.
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.Alpha" instead.
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'dense',
+ g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
+ delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ tau_decay: Union[float, ArrayType] = 10.0,
+ method: str = 'exp_auto',
+ name: str = None
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ conn_type=conn_type,
+ delay_step=delay_step,
+ g_max=g_max,
+ tau_decay=tau_decay,
+ tau_rise=tau_decay,
+ method=method,
+ name=name)
+
+
+class AlphaCOBA(DualExpCOBA):
+ """Conductance-based alpha synapse model.
+
+ .. deprecated:: 2.1.13
+ Please use "brainpy.synapses.Alpha" instead.
+
+ """
+
+ def __init__(
+ self,
+ pre: NeuDyn,
+ post: NeuDyn,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
+ conn_type: str = 'dense',
+ g_max: Union[float, ArrayType, Callable, Initializer] = 1.,
+ delay_step: Union[int, ArrayType, Initializer, Callable] = None,
+ tau_decay: Union[float, ArrayType] = 10.0,
+ E: Union[float, ArrayType] = 0.,
+ method: str = 'exp_auto',
+ name: str = None
+ ):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ conn_type=conn_type,
+ delay_step=delay_step,
+ g_max=g_max, E=E,
+ tau_decay=tau_decay,
+ tau_rise=tau_decay,
+ method=method,
+ name=name)
diff --git a/brainpy/_src/synapses/learning_rules.py b/brainpy/_src/dynold/synapses/learning_rules.py
similarity index 77%
rename from brainpy/_src/synapses/learning_rules.py
rename to brainpy/_src/dynold/synapses/learning_rules.py
index e35bd6686..583a2c01b 100644
--- a/brainpy/_src/synapses/learning_rules.py
+++ b/brainpy/_src/dynold/synapses/learning_rules.py
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
-from typing import Union, Dict, Callable
+from typing import Union, Dict, Callable, Optional
-import jax.numpy as jnp
-
-import brainpy.math as bm
-from brainpy._src.dynsys import NeuGroup, TwoEndConn
-from brainpy._src.initialize import Initializer, delay as init_delay
-from brainpy._src.integrators import odeint, JointEq
from brainpy._src.connect import TwoEndConnector
+from brainpy._src.dyn import synapses
+from brainpy._src.dynold.synouts import CUBA
+from brainpy._src.dynold.synapses import _TwoEndConnAlignPre
+from brainpy._src.dynsys import NeuDyn, Sequential
+from brainpy._src.initialize import Initializer
+from brainpy._src.mixin import ParamDesc
from brainpy.types import ArrayType
__all__ = [
@@ -16,7 +16,14 @@
]
-class STP(TwoEndConn):
+class _STPModel(Sequential, ParamDesc):
+ def __init__(self, size, keep_size, tau, U, tau_f, tau_d, mode=None, method='exp_euler'):
+ stp = synapses.STP(size, keep_size, U=U, tau_f=tau_f, tau_d=tau_d, method=method, mode=mode)
+ exp = synapses.Expon(size, keep_size, tau=tau, method=method, mode=mode)
+ super().__init__(stp, exp)
+
+
+class STP(_TwoEndConnAlignPre):
r"""Short-term plasticity model.
**Model Descriptions**
@@ -176,8 +183,8 @@ class STP(TwoEndConn):
def __init__(
self,
- pre: NeuGroup,
- post: NeuGroup,
+ pre: NeuDyn,
+ post: NeuDyn,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
U: Union[float, ArrayType] = 0.15,
tau_f: Union[float, ArrayType] = 1500.,
@@ -186,11 +193,8 @@ def __init__(
A: Union[float, ArrayType] = 1.,
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
method: str = 'exp_auto',
- name: str = None
+ name: Optional[str] = None
):
- super(STP, self).__init__(pre=pre, post=post, conn=conn, name=name)
- self.check_post_attrs('input')
-
# parameters
self.tau_d = tau_d
self.tau_f = tau_f
@@ -198,47 +202,28 @@ def __init__(
self.U = U
self.A = A
- # connections
- self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids')
+ syn = _STPModel.desc(pre.size,
+ pre.keep_size,
+ tau,
+ U,
+ tau_f,
+ tau_d,
+ mode=None,
+ method=method)
+
+ super().__init__(pre=pre,
+ post=post,
+ syn=syn,
+ conn=conn,
+ g_max=A,
+ output=CUBA(),
+ comp_method='sparse',
+ delay_step=delay_step,
+ name=name)
# variables
- self.num = len(self.pre_ids)
- self.x = bm.Variable(jnp.ones(self.num))
- self.u = bm.Variable(jnp.zeros(self.num))
- self.I = bm.Variable(jnp.zeros(self.num))
- self.delay_type, self.delay_step, self.delay_I = init_delay(delay_step, self.I)
-
- # integral
- self.integral = odeint(method=method, f=self.derivative)
-
- def reset(self):
- self.x.value = jnp.zeros(self.num)
- self.u.value = jnp.zeros(self.num)
- self.I.value = jnp.zeros(self.num)
- self.delay_I.reset(self.I)
-
- @property
- def derivative(self):
- dI = lambda I, t: -I / self.tau
- du = lambda u, t: - u / self.tau_f
- dx = lambda x, t: (1 - x) / self.tau_d
- return JointEq([dI, du, dx])
-
- def update(self, tdi):
- # delayed pre-synaptic spikes
- if self.delay_type == 'homo':
- delayed_I = self.delay_I(self.delay_step)
- elif self.delay_type == 'heter':
- delayed_I = self.delay_I(self.delay_step, jnp.arange(self.pre.num))
- else:
- delayed_I = self.I
- self.post.input += bm.syn2post(delayed_I, self.post_ids, self.post.num)
- self.I.value, u, x = self.integral(self.I, self.u, self.x, tdi.t, tdi.dt)
- syn_sps = bm.pre2syn(self.pre.spike, self.pre_ids)
- u = jnp.where(syn_sps, u + self.U * (1 - self.u), u)
- x = jnp.where(syn_sps, x - u * self.x, x)
- self.I.value = jnp.where(syn_sps, self.I + self.A * u * self.x, self.I.value)
- self.u.value = u
- self.x.value = x
- if self.delay_type in ['homo', 'heter']:
- self.delay_I.update(self.I)
+ syn = self.pre.after_updates[self.proj._syn_id].syn
+ self.x = syn[0].x
+ self.u = syn[0].u
+ self.I = syn[1].g
+
diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
new file mode 100644
index 000000000..badb60832
--- /dev/null
+++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+
+
+from absl.testing import parameterized
+
+import brainpy as bp
+import brainpy.math as bm
+from brainpy._src.dynold.synapses import abstract_models
+
+
+class Test_Abstract_Synapse(parameterized.TestCase):
+ @parameterized.product(
+ name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'],
+ stp=[None, bp.synplast.STD(), bp.synplast.STP()],
+ mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)]
+ )
+ def test_all2all_synapse(self, name, stp, mode):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(5)
+ post_neu = bp.neurons.LIF(5)
+ syn_model = getattr(bp.synapses, name)
+ syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp)
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'syn.g', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ expected_shape = (100, 5)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size, ) + expected_shape
+ self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'],
+ stp=[None, bp.synplast.STD(), bp.synplast.STP()],
+ mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)]
+ )
+ def test_one2one_synapse(self, name, stp, mode):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(5)
+ post_neu = bp.neurons.LIF(5)
+ syn_model = getattr(abstract_models, name)
+ syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp)
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'syn.g', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ expected_shape = (100, 5)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size, ) + expected_shape
+ self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ comp_type=['sparse', 'dense'],
+ name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'],
+ stp=[None, bp.synplast.STD(), bp.synplast.STP()],
+ mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)]
+ )
+ def test_sparse_synapse(self, comp_type, name, stp, mode):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(5)
+ post_neu = bp.neurons.LIF(5)
+ syn_model = getattr(abstract_models, name)
+ syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp)
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'syn.g', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ expected_shape = (100, 5)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size, ) + expected_shape
+ self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ post_ref_key=[None, 'refractory'],
+ stp=[None, bp.synplast.STD(), bp.synplast.STP()],
+ mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)]
+ )
+ def test_delta_synapse(self, post_ref_key, stp, mode):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(5, ref_var=True)
+ post_neu = bp.neurons.LIF(3, ref_var=True)
+ syn = bp.synapses.Delta(pre_neu, post_neu,
+ conn=bp.conn.All2All(),
+ post_ref_key=post_ref_key,
+ stp=stp, )
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ pre_expected_shape = (100, 5)
+ post_expected_shape = (100, 3)
+ if isinstance(mode, bm.BatchingMode):
+ pre_expected_shape = (mode.batch_size,) + pre_expected_shape
+ post_expected_shape = (mode.batch_size,) + post_expected_shape
+ self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape)
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py
new file mode 100644
index 000000000..395868092
--- /dev/null
+++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+
+
+from absl.testing import parameterized
+
+import brainpy as bp
+import brainpy.math as bm
+
+biological_models = [
+ bp.synapses.AMPA,
+ bp.synapses.GABAa,
+ bp.synapses.BioNMDA,
+]
+
+
+class Test_Biological_Synapse(parameterized.TestCase):
+ @parameterized.product(
+ synapse=biological_models,
+ delay_step=[None, 5, 1],
+ mode=[bm.NonBatchingMode(), bm.BatchingMode(5)],
+ stp=[None, bp.synplast.STP(), bp.synplast.STD()]
+ )
+ def test_all2all_synapse(self, synapse, delay_step, mode, stp):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(5)
+ post_neu = bp.neurons.LIF(5)
+ syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp)
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'syn.g', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ expected_shape = (100, 5)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size,) + expected_shape
+
+ self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ synapse=biological_models,
+ delay_step=[None, 10, 1],
+ mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ],
+ stp=[None, bp.synplast.STP(), bp.synplast.STD()]
+ )
+ def test_one2one_synapse(self, synapse, delay_step, mode, stp):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(5)
+ post_neu = bp.neurons.LIF(5)
+ syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp)
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'syn.g', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ expected_shape = (100, 5)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size,) + expected_shape
+ self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ synapse=biological_models,
+ comp_method=['sparse', 'dense'],
+ delay_step=[None, 10, 1],
+ mode=[bm.NonBatchingMode(), bm.BatchingMode(5)],
+ stp=[None, bp.synplast.STP(), bp.synplast.STD()]
+ )
+ def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ pre_neu = bp.neurons.LIF(10)
+ post_neu = bp.neurons.LIF(10)
+ syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5),
+ comp_method=comp_method, delay_step=delay_step,
+ stp=stp)
+ net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
+
+ # run the simulation
+ runner = bp.DSRunner(net,
+ monitors=['pre.V', 'syn.g', 'post.V'],
+ inputs=('pre.input', 35.))
+ runner(10.)
+
+ expected_shape = (100, 10)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size,) + expected_shape
+ self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape)
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/dynold/synapses/tests/test_learning_rule.py b/brainpy/_src/dynold/synapses/tests/test_learning_rule.py
new file mode 100644
index 000000000..8c1c9d049
--- /dev/null
+++ b/brainpy/_src/dynold/synapses/tests/test_learning_rule.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+
+import brainpy as bp
+import brainpy.math as bm
+from absl.testing import parameterized
+
+
+class Test_learning_rule(parameterized.TestCase):
+ @parameterized.product(
+ delay_step=[None, 5, 1],
+ mode=[bm.NonBatchingMode(), bm.BatchingMode(5), bm.TrainingMode(5)]
+ )
+ def test_learning_rule(self, delay_step, mode):
+ bm.random.seed()
+ with bm.environment(mode=mode):
+ neu1 = bp.neurons.LIF(5)
+ neu2 = bp.neurons.LIF(5)
+ syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.1, tau_d=10, tau_f=100.,
+ delay_step=delay_step)
+ net = bp.Network(pre=neu1, syn=syn1, post=neu2)
+
+ runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x'])
+ runner.run(10.)
+
+ expected_shape = (100, 5)
+ if isinstance(mode, bm.BatchingMode):
+ expected_shape = (mode.batch_size,) + expected_shape
+ self.assertTupleEqual(runner.mon['syn.I'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.u'].shape, expected_shape)
+ self.assertTupleEqual(runner.mon['syn.x'].shape, expected_shape)
+ bm.clear_buffer_memory()
+
diff --git a/brainpy/_src/synouts/__init__.py b/brainpy/_src/dynold/synouts/__init__.py
similarity index 100%
rename from brainpy/_src/synouts/__init__.py
rename to brainpy/_src/dynold/synouts/__init__.py
diff --git a/brainpy/_src/synouts/conductances.py b/brainpy/_src/dynold/synouts/conductances.py
similarity index 90%
rename from brainpy/_src/synouts/conductances.py
rename to brainpy/_src/dynold/synouts/conductances.py
index 2b77e67a7..9c0562fbf 100644
--- a/brainpy/_src/synouts/conductances.py
+++ b/brainpy/_src/dynold/synouts/conductances.py
@@ -2,19 +2,18 @@
from typing import Union, Callable, Optional
-from brainpy.math import Variable
-from brainpy._src.dynsys import SynOut
+from brainpy._src.dynold.synapses.base import _SynOut
from brainpy._src.initialize import parameter, Initializer
+from brainpy.math import Variable
from brainpy.types import ArrayType
-
__all__ = [
'COBA',
'CUBA',
]
-class CUBA(SynOut):
+class CUBA(_SynOut):
r"""Current-based synaptic output.
Given the conductance, this model outputs the post-synaptic current with an identity function:
@@ -40,13 +39,13 @@ def __init__(
name: str = None,
):
self._target_var = target_var
- super(CUBA, self).__init__(name=name, target_var=target_var)
+ super().__init__(name=name, target_var=target_var)
def clone(self):
return CUBA(target_var=self._target_var)
-class COBA(SynOut):
+class COBA(_SynOut):
r"""Conductance-based synaptic output.
Given the synaptic conductance, the model outputs the post-synaptic current with
@@ -74,7 +73,7 @@ def __init__(
membrane_var: Union[str, Variable] = 'V',
name: str = None,
):
- super(COBA, self).__init__(name=name, target_var=target_var)
+ super().__init__(name=name, target_var=target_var)
self._E = E
self._target_var = target_var
self._membrane_var = membrane_var
@@ -85,7 +84,7 @@ def clone(self):
membrane_var=self._membrane_var)
def register_master(self, master):
- super(COBA, self).register_master(master)
+ super().register_master(master)
# reversal potential
self.E = parameter(self._E, self.master.post.num, allow_none=False)
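The hunks above only swap the base class from ``SynOut`` to ``_SynOut`` and modernize the ``super()`` calls; the output rules themselves are unchanged. As a reminder of what the two classes compute, a minimal sketch (illustrative only; the function names below are placeholders, not part of the module):

    def cuba_output(g):
        # CUBA: the synaptic conductance is returned unchanged (identity output)
        return g

    def coba_output(g, V, E):
        # COBA: the conductance is scaled by the driving force (E - V) on the post-synaptic membrane
        return g * (E - V)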
diff --git a/brainpy/_src/synouts/ions.py b/brainpy/_src/dynold/synouts/ions.py
similarity index 94%
rename from brainpy/_src/synouts/ions.py
rename to brainpy/_src/dynold/synouts/ions.py
index 46faacef0..da5b511d7 100644
--- a/brainpy/_src/synouts/ions.py
+++ b/brainpy/_src/dynold/synouts/ions.py
@@ -5,17 +5,16 @@
import jax.numpy as jnp
import brainpy.math as bm
-from brainpy._src.dynsys import SynOut
+from brainpy._src.dynold.synapses.base import _SynOut
from brainpy._src.initialize import parameter, Initializer
from brainpy.types import ArrayType
-
__all__ = [
'MgBlock',
]
-class MgBlock(SynOut):
+class MgBlock(_SynOut):
r"""Synaptic output based on Magnesium blocking.
   Given the synaptic conductance, the model outputs the post-synaptic current with
@@ -56,7 +55,7 @@ def __init__(
membrane_var: Union[str, bm.Variable] = 'V',
name: str = None,
):
- super(MgBlock, self).__init__(name=name, target_var=target_var)
+ super().__init__(name=name, target_var=target_var)
self._E = E
self._cc_Mg = cc_Mg
self._alpha = alpha
@@ -65,7 +64,7 @@ def __init__(
self._membrane_var = membrane_var
def register_master(self, master):
- super(MgBlock, self).register_master(master)
+ super().register_master(master)
self.E = parameter(self._E, self.master.post.num, allow_none=False)
self.cc_Mg = parameter(self._cc_Mg, self.master.post.num, allow_none=False)
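For context, ``MgBlock`` gates the conductance by a voltage-dependent magnesium term before applying the driving force, using the parameters registered above (``E``, ``cc_Mg``, ``alpha``, and the usual companion parameter ``beta``, which is assumed here). A rough sketch of the commonly used Jahr-Stevens form this class is based on; the authoritative expression is the class's own ``filter()`` method:

    import jax.numpy as jnp

    def mg_block_output(g, V, E, cc_Mg, alpha, beta):
        # fraction of NMDA channels not blocked by extracellular Mg2+ at membrane potential V
        unblocked = 1. / (1. + cc_Mg / beta * jnp.exp(-alpha * V))
        return g * unblocked * (E - V)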
diff --git a/brainpy/_src/synplast/__init__.py b/brainpy/_src/dynold/synplast/__init__.py
similarity index 100%
rename from brainpy/_src/synplast/__init__.py
rename to brainpy/_src/dynold/synplast/__init__.py
diff --git a/brainpy/_src/synplast/short_term_plasticity.py b/brainpy/_src/dynold/synplast/short_term_plasticity.py
similarity index 88%
rename from brainpy/_src/synplast/short_term_plasticity.py
rename to brainpy/_src/dynold/synplast/short_term_plasticity.py
index f933cf321..da3428662 100644
--- a/brainpy/_src/synplast/short_term_plasticity.py
+++ b/brainpy/_src/dynold/synplast/short_term_plasticity.py
@@ -4,7 +4,8 @@
import jax.numpy as jnp
-from brainpy._src.dynsys import SynSTP
+from brainpy._src.context import share
+from brainpy._src.dynold.synapses.base import _SynSTP
from brainpy._src.initialize import variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.check import is_float
@@ -16,7 +17,7 @@
]
-class STD(SynSTP):
+class STD(_SynSTP):
r"""Synaptic output with short-term depression.
This model filters the synaptic current by the following equation:
@@ -69,17 +70,18 @@ def __init__(
# integral function
self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=self.method)
- def register_master(self, master):
- super(STD, self).register_master(master)
+ def clone(self):
+ return STD(tau=self.tau, U=self.U, method=self.method)
- # variables
+ def register_master(self, master):
+ super().register_master(master)
self.x = variable(jnp.ones, self.master.mode, self.master.pre.num)
def reset_state(self, batch_size=None):
self.x.value = variable(jnp.ones, batch_size, self.master.pre.num)
- def update(self, tdi, pre_spike):
- x = self.integral(self.x.value, tdi['t'], tdi['dt'])
+ def update(self, pre_spike):
+ x = self.integral(self.x.value, share['t'], share['dt'])
self.x.value = jnp.where(pre_spike, x - self.U * self.x, x)
def filter(self, g):
@@ -88,7 +90,7 @@ def filter(self, g):
return g * self.x
-class STP(SynSTP):
+class STP(_SynSTP):
r"""Synaptic output with short-term plasticity.
This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`.
@@ -153,10 +155,11 @@ def __init__(
# integral function
self.integral = odeint(self.derivative, method=self.method)
- def register_master(self, master):
- super(STP, self).register_master(master)
+ def clone(self):
+ return STP(tau_f=self.tau_f, tau_d=self.tau_d, U=self.U, method=self.method)
- # variables
+ def register_master(self, master):
+ super().register_master(master)
self.x = variable(jnp.ones, self.master.mode, self.master.pre.num)
self.u = variable(lambda s: jnp.ones(s) * self.U, self.master.mode, self.master.pre.num)
@@ -168,10 +171,10 @@ def reset_state(self, batch_size=None):
def derivative(self):
du = lambda u, t: self.U - u / self.tau_f
dx = lambda x, t: (1 - x) / self.tau_d
- return JointEq([du, dx])
+ return JointEq(du, dx)
- def update(self, tdi, pre_spike):
- u, x = self.integral(self.u.value, self.x.value, tdi['t'], tdi['dt'])
+ def update(self, pre_spike):
+ u, x = self.integral(self.u.value, self.x.value, share['t'], share['dt'])
u = jnp.where(pre_spike, u + self.U * (1 - self.u), u)
x = jnp.where(pre_spike, x - u * self.x, x)
self.x.value = x
@@ -181,4 +184,3 @@ def filter(self, g):
if jnp.shape(g) != self.x.shape:
raise ValueError('Shape does not match.')
return g * self.x * self.u
-
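The refactor above moves the shared time variables into the global ``share`` context and adds ``clone()``; the plasticity dynamics themselves are unchanged. Restating the STP update visible in the hunks, with a plain forward-Euler step standing in for the ``odeint`` integrator (a simplification for illustration only):

    import jax.numpy as jnp

    def stp_step(u, x, pre_spike, U, tau_f, tau_d, dt):
        # drift terms, written as the derivatives defined in the hunk above
        u = u + (U - u / tau_f) * dt
        x = x + (1. - x) / tau_d * dt
        # event-driven facilitation of u and depression of x on a presynaptic spike
        u = jnp.where(pre_spike, u + U * (1. - u), u)
        x = jnp.where(pre_spike, x - u * x, x)
        return u, x

The filtered conductance is then ``g * x * u``, exactly as returned by the ``filter()`` method above.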
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 1eb5bb3cd..5465d1898 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -2,47 +2,28 @@
import collections
import gc
-from typing import Union, Dict, Callable, Sequence, Optional, Tuple
+import inspect
+from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
-import jax
-import jax.numpy as jnp
import numpy as np
-from brainpy import tools
-from brainpy._src import math as bm
-from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
-from brainpy._src.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
-from brainpy._src.integrators import odeint, sdeint
-from brainpy._src.math.object_transform.variables import Variable, VariableView
-from brainpy._src.math.object_transform.base import BrainPyObject, Collector
+from brainpy import tools, math as bm
+from brainpy._src.initialize import parameter, variable_
+from brainpy._src.mixin import AutoDelaySupp, ParamDesc, Container, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
share = None
__all__ = [
- # general class
+ # general
'DynamicalSystem',
- 'DynamicalSystemNS',
# containers
- 'Container', 'Network', 'Sequential', 'System',
+ 'DynSysGroup', 'Network', 'Sequential',
- # channel models
- 'Channel',
-
- # neuron models
- 'NeuGroup', 'CondNeuGroup', 'NeuGroupNS',
-
- # synapse models
- 'SynConn',
- 'TwoEndConn',
- 'SynOut', 'NullSynOut',
- 'SynSTP',
- 'SynLTP',
-
- # slice
- 'DSView', 'NeuGroupView',
+ # base classes
+ 'NeuDyn', 'SynDyn', 'IonChaDyn',
]
SLICE_VARS = 'slice_vars'
@@ -88,23 +69,26 @@ def update(self, x):
return func
-class DynamicalSystem(BrainPyObject):
+class DynamicalSystem(bm.BrainPyObject, DelayRegister):
"""Base Dynamical System class.
.. note::
In general, every instance of :py:class:`~.DynamicalSystem` implemented in
BrainPy only defines the evolving function at each time step :math:`t`.
- Each subclass of :py:class:`~.DynamicalSystem` may have multiple step functions.
- For instance, all our implemented neuron model define two step functions:
-
- - ``.update()`` for the logic updating
- - ``clear_input()`` for clear all accumulated inputs at this time step.
-
If users want to define the logic of running models across multiple steps,
we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
:py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.
-
+
+ To be compatible with previous APIs, :py:class:`~.DynamicalSystem` inherits
+  from the :py:class:`~.DelayRegister`. It is worth noting that the methods of
+ :py:class:`~.DelayRegister` will be removed in the future, including:
+
+ - ``.register_delay()``
+ - ``.get_delay_data()``
+ - ``.update_local_delays()``
+ - ``.reset_local_delays()``
+
Parameters
----------
name : optional, str
@@ -116,13 +100,6 @@ class DynamicalSystem(BrainPyObject):
supported_modes: Optional[Sequence[bm.Mode]] = None
'''Supported computing modes.'''
- _pass_shared_args: bool = True
-
- global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict()
- '''Global delay data, which stores the delay variables and corresponding delay targets.
- This variable is useful when the same target variable is used in multiple mappings,
- as it can reduce the duplicate delay variable registration.'''
-
def __init__(
self,
name: Optional[str] = None,
@@ -141,11 +118,46 @@ def __init__(
f'which are parents of {self.supported_modes}, '
f'but we got {self.mode}.')
- # local delay variables
- self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
+ # local delay variables:
+ # Compatible for ``DelayRegister``
+ self.local_delay_vars: Dict = bm.node_dict()
+
+ # the before- / after-updates used for computing
+ # added after the version of 2.4.3
+ self.before_updates: Dict[str, Callable] = bm.node_dict()
+ self.after_updates: Dict[str, Callable] = bm.node_dict()
# super initialization
- BrainPyObject.__init__(self, name=name)
+ super().__init__(name=name)
+
+ def update(self, *args, **kwargs):
+ """The function to specify the updating rule.
+
+ Assume any dynamical system depends on the shared variables (`sha`),
+    like the time variable ``t``, the integration step ``dt``, and the step index ``i``.
+ """
+    raise NotImplementedError('Subclasses must implement the "update" function.')
+
+ def reset(self, *args, **kwargs):
+ """Reset function which reset the whole variables in the model.
+ """
+ self.reset_state(*args, **kwargs)
+
+ def reset_state(self, *args, **kwargs):
+ """Reset function which reset the states in the model.
+ """
+ child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
+ if len(child_nodes) > 0:
+ for node in child_nodes.values():
+ node.reset_state(*args, **kwargs)
+ self.reset_local_delays(child_nodes)
+ else:
+      raise NotImplementedError('Subclasses must implement the "reset_state" function. '
+                                f'Error in {self.name}')
+
+ def clear_input(self):
+ """Clear the input at the current time step."""
+ pass
@property
def mode(self) -> bm.Mode:
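To make the contract above concrete, here is a minimal sketch of a subclass that implements ``update()`` and ``reset_state()`` under the new scheme, reading the step size from ``brainpy.math`` rather than from a ``tdi`` argument (the class and its dynamics are hypothetical, purely for illustration):

    import brainpy as bp
    import brainpy.math as bm

    class LeakyUnit(bp.DynamicalSystem):
      def __init__(self, num, tau=10.):
        super().__init__()
        self.tau = tau
        self.v = bm.Variable(bm.zeros(num))

      def update(self, x):
        # no shared-argument dict is needed: the step size comes from the global setting
        self.v.value += (x - self.v) / self.tau * bm.get_dt()
        return self.v.value

      def reset_state(self, batch_size=None):
        self.v.value = bm.zeros_like(self.v)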
@@ -159,256 +171,119 @@ def mode(self, value):
f'but we got {type(value)}: {value}')
self._mode = value
- def __repr__(self):
- return f'{self.__class__.__name__}(name={self.name}, mode={self.mode})'
-
- def __call__(self, *args, **kwargs):
- """The shortcut to call ``update`` methods."""
+ def _compatible_update(self, *args, **kwargs):
global share
if share is None:
from brainpy._src.context import share
-
- try:
- if self._pass_shared_args:
- if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
- if len(args) and isinstance(args[0], dict):
- share.save(**args[0])
- return self.update(*args[1:], **kwargs)
- else:
- return self.update(*args, **kwargs)
+ update_fun = super().__getattribute__('update')
+ update_args = tuple(inspect.signature(update_fun).parameters.values())
+
+ if len(update_args) and update_args[0].name in ['tdi', 'sh', 'sha']:
+ if len(args) > 0:
+ if isinstance(args[0], dict):
+ # define:
+ # update(tdi, *args, **kwargs)
+ # call:
+ # update(tdi, *args, **kwargs)
+ ret = update_fun(*args, **kwargs)
+ # TODO: deprecation
else:
- if len(args) and isinstance(args[0], dict):
- return self.update(*args, **kwargs)
- else:
- # If first argument is not shared argument,
- # we should get the shared arguments from the global context.
- # However, users should set and update shared arguments
- # in the global context when using this mode.
- return self.update(share.get_shargs(), *args, **kwargs)
+ # define:
+ # update(tdi, *args, **kwargs)
+ # call:
+ # update(*args, **kwargs)
+ ret = update_fun(share.get_shargs(), *args, **kwargs)
else:
- if len(args) and isinstance(args[0], dict): # it may be shared arguments
- share.save(**args[0])
- return self.update(*args[1:], **kwargs)
+ if update_args[0].name in kwargs:
+ # define:
+ # update(tdi, *args, **kwargs)
+ # call:
+ # update(tdi=??, **kwargs)
+ ret = update_fun(**kwargs)
else:
- return self.update(*args, **kwargs)
- except Exception as e:
- raise RuntimeError(f'Error occurs when running {self.name}: {self}') from e
+ # define:
+ # update(tdi, *args, **kwargs)
+ # call:
+ # update(**kwargs)
+ ret = update_fun(share.get_shargs(), *args, **kwargs)
+ return ret
- def register_delay(
- self,
- identifier: str,
- delay_step: Optional[Union[int, ArrayType, Callable, Initializer]],
- delay_target: Variable,
- initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None,
- ):
- """Register delay variable.
-
- Parameters
- ----------
- identifier: str
- The delay variable name.
- delay_step: Optional, int, ArrayType, callable, Initializer
- The number of the steps of the delay.
- delay_target: Variable
- The target variable for delay.
- initial_delay_data: float, int, ArrayType, callable, Initializer
- The initializer for the delay data.
-
- Returns
- -------
- delay_step: int, ArrayType
- The number of the delay steps.
- """
- # delay steps
- if delay_step is None:
- delay_type = 'none'
- elif isinstance(delay_step, (int, np.integer, jnp.integer)):
- delay_type = 'homo'
- elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
- if delay_step.size == 1 and delay_step.ndim == 0:
- delay_type = 'homo'
- else:
- delay_type = 'heter'
- delay_step = bm.asarray(delay_step)
- elif callable(delay_step):
- delay_step = parameter(delay_step, delay_target.shape, allow_none=False)
- delay_type = 'heter'
- else:
- raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support '
- f'integer, array of integers, callable function, brainpy.init.Initializer.')
- if delay_type == 'heter':
- if delay_step.dtype not in [bm.int32, bm.int64]:
- raise ValueError('Only support delay steps of int32, int64. If your '
- 'provide delay time length, please divide the "dt" '
- 'then provide us the number of delay steps.')
- if delay_target.shape[0] != delay_step.shape[0]:
- raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
- if delay_type != 'none':
- max_delay_step = int(bm.max(delay_step))
-
- # delay target
- if delay_type != 'none':
- if not isinstance(delay_target, Variable):
- raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}')
-
- # delay variable
- if delay_type != 'none':
- if identifier not in self.global_delay_data:
- delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
- self.global_delay_data[identifier] = (delay, delay_target)
- self.local_delay_vars[identifier] = delay
+ try:
+ ba = inspect.signature(update_fun).bind(*args, **kwargs)
+ except TypeError:
+ if len(args) and isinstance(args[0], dict):
+        # the user-defined ``update()`` function does not accept the shared argument,
+        # but a shared-argument dict was provided when calling ``update()``
+ # -----
+ # change
+ # update(tdi, *args, **kwargs)
+ # as
+ # update(*args, **kwargs)
+ share.save(**args[0])
+ ret = update_fun(*args[1:], **kwargs)
+ return ret
else:
- delay = self.global_delay_data[identifier][0]
- if delay is None:
- delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
- self.global_delay_data[identifier] = (delay, delay_target)
- self.local_delay_vars[identifier] = delay
- elif delay.num_delay_step - 1 < max_delay_step:
- self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
+        # the user-defined ``update()`` function expects the shared argument,
+        # but it was not provided when calling ``update()``
+ # -----
+ # change
+ # update(*args, **kwargs)
+ # as
+ # update(tdi, *args, **kwargs)
+ ret = update_fun(share.get_shargs(), *args, **kwargs)
+ return ret
else:
- if identifier not in self.global_delay_data:
- self.global_delay_data[identifier] = (None, delay_target)
- self.register_implicit_nodes(self.local_delay_vars)
- return delay_step
-
- def get_delay_data(
- self,
- identifier: str,
- delay_step: Optional[Union[int, bm.Array, jax.Array]],
- *indices: Union[int, slice, bm.Array, jax.Array],
- ):
- """Get delay data according to the provided delay steps.
-
- Parameters
- ----------
- identifier: str
- The delay variable name.
- delay_step: Optional, int, ArrayType
- The delay length.
- indices: optional, int, slice, ArrayType
- The indices of the delay.
-
- Returns
- -------
- delay_data: ArrayType
- The delay data at the given time.
- """
- if delay_step is None:
- return self.global_delay_data[identifier][1].value
-
- if identifier in self.global_delay_data:
- if bm.ndim(delay_step) == 0:
- return self.global_delay_data[identifier][0](delay_step, *indices)
- else:
- if len(indices) == 0:
- indices = (bm.arange(delay_step.size),)
- return self.global_delay_data[identifier][0](delay_step, *indices)
-
- elif identifier in self.local_delay_vars:
- if bm.ndim(delay_step) == 0:
- return self.local_delay_vars[identifier](delay_step)
- else:
- if len(indices) == 0:
- indices = (bm.arange(delay_step.size),)
- return self.local_delay_vars[identifier](delay_step, *indices)
+ return update_fun(*args, **kwargs)
+ def __getattribute__(self, item):
+ if item == 'update':
+ return self._compatible_update # update function compatible with previous ``update()`` function
else:
- raise ValueError(f'{identifier} is not defined in delay variables.')
-
- def update(self, *args, **kwargs):
- """The function to specify the updating rule.
-
- Assume any dynamical system depends on the shared variables (`sha`),
- like time variable ``t``, the step precision ``dt``, and the time step `i`.
- """
- raise NotImplementedError('Must implement "update" function by subclass self.')
+ return super().__getattribute__(item)
- def reset(self, *args, **kwargs):
- """Reset function which reset the whole variables in the model.
- """
- self.reset_state(*args, **kwargs)
+ def _get_update_fun(self):
+ return object.__getattribute__(self, 'update')
- def reset_state(self, *args, **kwargs):
- """Reset function which reset the states in the model.
- """
- child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
- if len(child_nodes) > 0:
- for node in child_nodes.values():
- node.reset_state(*args, **kwargs)
- self.reset_local_delays(child_nodes)
- else:
- raise NotImplementedError('Must implement "reset_state" function by subclass self. '
- f'Error of {self.name}')
+ def __repr__(self):
+ return f'{self.name}(mode={self.mode})'
- def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
- """Update local delay variables.
+ def __call__(self, *args, **kwargs):
+ """The shortcut to call ``update`` methods."""
- This function should be called after updating neuron groups or delay sources.
- For example, in a network model,
+ # update ``before_updates``
+ for model in self.before_updates.values():
+ model()
+ # update the model self
+ ret = self.update(*args, **kwargs)
- Parameters
- ----------
- nodes: sequence, dict
- The nodes to update their delay variables.
- """
- # update delays
- if nodes is None:
- nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values())
- elif isinstance(nodes, DynamicalSystem):
- nodes = (nodes,)
- elif isinstance(nodes, dict):
- nodes = tuple(nodes.values())
- if not isinstance(nodes, (tuple, list)):
- raise ValueError('Please provide nodes as a list/tuple/dict of DynamicalSystem.')
- for node in nodes:
- for name in node.local_delay_vars:
- delay = self.global_delay_data[name][0]
- target = self.global_delay_data[name][1]
- delay.update(target.value)
-
- def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
- """Reset local delay variables.
-
- Parameters
- ----------
- nodes: sequence, dict
- The nodes to Reset their delay variables.
- """
- # reset delays
- if nodes is None:
- nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()
- elif isinstance(nodes, dict):
- nodes = nodes.values()
- for node in nodes:
- for name in node.local_delay_vars:
- delay = self.global_delay_data[name][0]
- target = self.global_delay_data[name][1]
- delay.reset(target.value)
+ # update ``after_updates``
+ for model in self.after_updates.values():
+ model(ret)
+ return ret
def __del__(self):
"""Function for handling `del` behavior.
This function is used to pop out the variables which registered in global delay data.
"""
- if hasattr(self, 'local_delay_vars'):
- for key in tuple(self.local_delay_vars.keys()):
- val = self.global_delay_data.pop(key)
- del val
- val = self.local_delay_vars.pop(key)
- del val
- if hasattr(self, 'implicit_nodes'):
- for key in tuple(self.implicit_nodes.keys()):
- del self.implicit_nodes[key]
- if hasattr(self, 'implicit_vars'):
- for key in tuple(self.implicit_vars.keys()):
- del self.implicit_vars[key]
- for key in tuple(self.__dict__.keys()):
- del self.__dict__[key]
- gc.collect()
-
- def clear_input(self):
- pass
+ try:
+ if hasattr(self, 'local_delay_vars'):
+ for key in tuple(self.local_delay_vars.keys()):
+ val = global_delay_data.pop(key)
+ del val
+ val = self.local_delay_vars.pop(key)
+ del val
+ if hasattr(self, 'implicit_nodes'):
+ for key in tuple(self.implicit_nodes.keys()):
+ del self.implicit_nodes[key]
+ if hasattr(self, 'implicit_vars'):
+ for key in tuple(self.implicit_vars.keys()):
+ del self.implicit_vars[key]
+ for key in tuple(self.__dict__.keys()):
+ del self.__dict__[key]
+ finally:
+ gc.collect()
def __rrshift__(self, other):
"""Support using right shift operator to call modules.
@@ -420,103 +295,41 @@ def __rrshift__(self, other):
>>> x = bp.math.random.rand((10, 10))
>>> l = bp.layers.Activation(bm.tanh)
>>> y = x >> l
-
"""
return self.__call__(other)
-class Container(DynamicalSystem):
- """Container object which is designed to add other instances of DynamicalSystem.
+class DynSysGroup(DynamicalSystem, Container):
+ """A group of :py:class:`~.DynamicalSystem`s in which the updating order does not matter.
- Parameters
- ----------
- name : str, optional
- The object name.
- mode: Mode
- The mode which controls the model computation.
- must_be_dynsys_subclass: bool
- Child classes must be the subclass of :py:class:`DynamicalSystem`.
+ Args:
+ children_as_tuple: The children objects.
+ children_as_dict: The children objects.
+ name: The object name.
+ mode: The mode which controls the model computation.
+ child_type: The type of the children object. Default is :py:class:`DynamicalSystem`.
"""
def __init__(
self,
- *dynamical_systems_as_tuple,
- name: str = None,
- mode: bm.Mode = None,
- must_be_dynsys_subclass: bool = True,
- **dynamical_systems_as_dict
+ *children_as_tuple,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ child_type: type = DynamicalSystem,
+ **children_as_dict
):
- super(Container, self).__init__(name=name, mode=mode)
-
- if must_be_dynsys_subclass:
- parent = DynamicalSystem
- parent_name = DynamicalSystem.__name__
- else:
- parent = bm.BrainPyObject
- parent_name = bm.BrainPyObject.__name__
-
- # add tuple-typed components
- for module in dynamical_systems_as_tuple:
- if isinstance(module, parent):
- self.implicit_nodes[module.name] = module
- elif isinstance(module, (list, tuple)):
- for m in module:
- if not isinstance(m, parent):
- raise ValueError(f'Should be instance of {parent_name}. '
- f'But we got {type(m)}')
- self.implicit_nodes[m.name] = m
- elif isinstance(module, dict):
- for k, v in module.items():
- if not isinstance(v, parent):
- raise ValueError(f'Should be instance of {parent_name}. '
- f'But we got {type(v)}')
- self.implicit_nodes[k] = v
- else:
- raise ValueError(f'Cannot parse sub-systems. They should be {parent_name} '
- f'or a list/tuple/dict of {parent_name}.')
- # add dict-typed components
- for k, v in dynamical_systems_as_dict.items():
- if not isinstance(v, parent):
- raise ValueError(f'Should be instance of {parent_name}. '
- f'But we got {type(v)}')
- self.implicit_nodes[k] = v
+ super().__init__(name=name, mode=mode)
- def __repr__(self):
- cls_name = self.__class__.__name__
- indent = ' ' * len(cls_name)
- child_str = [tools.repr_context(repr(val), indent) for val in self.implicit_nodes.values()]
- string = ", \n".join(child_str)
- return f'{cls_name}({string})'
+ self.children = bm.node_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict))
- def update(self, tdi, *args, **kwargs):
+ def update(self):
"""Update function of a container.
In this update function, the update functions in children systems are
iteratively called.
-
- Parameters
- ----------
- tdi: dict
- The shared arguments including `t` the time, `dt` the time step, `t` the running index.
"""
- nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
- for node in nodes.values():
- node(tdi)
-
- def __getitem__(self, item):
- """Overwrite the slice access (`self['']`). """
- if item in self.implicit_nodes:
- return self.implicit_nodes[item]
- else:
- raise ValueError(f'Unknown item {item}, we only found {list(self.implicit_nodes.keys())}')
-
- def __getattr__(self, item):
- """Overwrite the dot access (`self.`). """
- child_ds = super(Container, self).__getattribute__('implicit_nodes')
- if item in child_ds:
- return child_ds[item]
- else:
- return super(Container, self).__getattribute__(item)
+ for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
+ node()
def clear_input(self):
"""Clear inputs in the children classes."""
@@ -524,34 +337,10 @@ def clear_input(self):
node.clear_input()
-
-class Network(Container):
- """Base class to model network objects, an alias of Container.
-
- Network instantiates a network, which is aimed to load
- neurons, synapses, and other brain objects.
-
- Parameters
- ----------
- name : str, Optional
- The network name.
- monitors : optional, list of str, tuple of str
- The items to monitor.
- ds_tuple :
- A list/tuple container of dynamical system.
- ds_dict :
- A dict container of dynamical system.
+class Network(DynSysGroup):
+ """A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network.
"""
- def __init__(
- self,
- *ds_tuple,
- name: str = None,
- mode: bm.Mode = None,
- **ds_dict
- ):
- super(Network, self).__init__(*ds_tuple, name=name, mode=mode, **ds_dict)
-
@not_pass_shared
def update(self, *args, **kwargs):
"""Step function of a network.
@@ -559,58 +348,175 @@ def update(self, *args, **kwargs):
In this update function, the update functions in children systems are
iteratively called.
"""
- nodes = self.nodes(level=1, include_self=False)
- nodes = nodes.subset(DynamicalSystem)
- nodes = nodes.unique()
- neuron_groups = nodes.subset(NeuGroup)
- synapse_groups = nodes.subset(SynConn)
- ds_views = nodes.subset(DSView)
- other_nodes = nodes - neuron_groups - synapse_groups - ds_views
-
- # shared arguments
-
- # update synapse nodes
- for node in synapse_groups.values():
- node()
+ nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
- # update neuron nodes
- for node in neuron_groups.values():
+ # update nodes of projections
+ for node in nodes.subset(Projection).values():
node()
- # update other types of nodes
- for node in other_nodes.values():
+ # update nodes of dynamics
+ for node in nodes.subset(Dynamics).values():
node()
- # update delays
- self.update_local_delays(nodes)
+ # update nodes with other types, including delays, ...
+ for node in nodes.not_subset(Dynamics).not_subset(Projection).values():
+ node()
def reset_state(self, batch_size=None):
- nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
- neuron_groups = nodes.subset(NeuGroup)
- synapse_groups = nodes.subset(SynConn)
+ nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
- # reset neuron nodes
- for node in neuron_groups.values():
+ # reset dynamics
+ for node in nodes.subset(Dynamics).values():
node.reset_state(batch_size)
- # reset synapse nodes
- for node in synapse_groups.values():
+ # reset projections
+ for node in nodes.subset(Projection).values():
node.reset_state(batch_size)
- # reset other types of nodes
- for node in (nodes - neuron_groups - synapse_groups).values():
+ # reset other types of nodes, including delays, ...
+ for node in nodes.not_subset(Dynamics).not_subset(Projection).values():
node.reset_state(batch_size)
- # reset delays
- self.reset_local_delays(nodes)
+
+class Sequential(DynamicalSystem, AutoDelaySupp):
+ """A sequential `input-output` module.
+
+ Modules will be added to it in the order they are passed in the
+ constructor. Alternatively, an ``dict`` of modules can be
+ passed in. The ``update()`` method of ``Sequential`` accepts any
+ input and forwards it to the first module it contains. It then
+ "chains" outputs to inputs sequentially for each subsequent module,
+ finally returning the output of the last module.
+
+ The value a ``Sequential`` provides over manually calling a sequence
+ of modules is that it allows treating the whole container as a
+ single module, such that performing a transformation on the
+ ``Sequential`` applies to each of the modules it stores (which are
+ each a registered submodule of the ``Sequential``).
+
+ What's the difference between a ``Sequential`` and a
+ :py:class:`Container`? A ``Container`` is exactly what it
+ sounds like--a container to store :py:class:`DynamicalSystem` s!
+ On the other hand, the layers in a ``Sequential`` are connected
+ in a cascading way.
+
+ Examples
+ --------
+
+ >>> import brainpy as bp
+ >>> import brainpy.math as bm
+ >>>
+ >>> # composing ANN models
+ >>> l = bp.Sequential(bp.layers.Dense(100, 10),
+ >>> bm.relu,
+ >>> bp.layers.Dense(10, 2))
+ >>> l({}, bm.random.random((256, 100)))
+ >>>
+ >>> # Using Sequential with Dict. This is functionally the
+ >>> # same as the above code
+ >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10),
+ >>> l2=bm.relu,
+ >>> l3=bp.layers.Dense(10, 2))
+ >>> l({}, bm.random.random((256, 100)))
+
+
+ Args:
+ modules_as_tuple: The children modules.
+ modules_as_dict: The children modules.
+ name: The object name.
+ mode: The object computing context/mode. Default is ``None``.
+ """
+
+ def __init__(
+ self,
+ *modules_as_tuple,
+ name: str = None,
+ mode: bm.Mode = None,
+ **modules_as_dict
+ ):
+ super().__init__(name=name, mode=mode)
+ self._dyn_modules = bm.NodeDict()
+ self._static_modules = dict()
+ i = 0
+ for m in modules_as_tuple + tuple(modules_as_dict.values()):
+ key = self.__format_key(i)
+ if isinstance(m, bm.BrainPyObject):
+ self._dyn_modules[key] = m
+ else:
+ self._static_modules[key] = m
+ i += 1
+ self._num = i
+
+ def update(self, x):
+ """Update function of a sequential model.
+ """
+ for m in self.__all_nodes():
+ x = m(x)
+ return x
+
+ def return_info(self):
+ last = self[-1]
+ if not isinstance(last, AutoDelaySupp):
+ raise UnsupportedError(f'Does not support "return_info()" because the last node is '
+ f'not instance of {AutoDelaySupp.__name__}')
+ return last.return_info()
+
+ def append(self, module: Callable):
+ assert isinstance(module, Callable)
+ key = self.__format_key(self._num)
+ if isinstance(module, bm.BrainPyObject):
+ self._dyn_modules[key] = module
+ else:
+ self._static_modules[key] = module
+ self._num += 1
+
+ def __format_key(self, i):
+ return f'l-{i}'
+
+ def __all_nodes(self):
+ nodes = []
+ for i in range(self._num):
+ key = self.__format_key(i)
+ if key not in self._dyn_modules:
+ nodes.append(self._static_modules[key])
+ else:
+ nodes.append(self._dyn_modules[key])
+ return nodes
+
+ def __getitem__(self, key: Union[int, slice, str]):
+ if isinstance(key, str):
+ if key in self._dyn_modules:
+ return self._dyn_modules[key]
+ elif key in self._static_modules:
+ return self._static_modules[key]
+ else:
+        raise KeyError(f'Cannot find a component named {key} in\n {str(self)}')
+ elif isinstance(key, slice):
+ return Sequential(*(self.__all_nodes()[key]))
+ elif isinstance(key, int):
+ return self.__all_nodes()[key]
+ elif isinstance(key, (tuple, list)):
+ _all_nodes = self.__all_nodes()
+ return Sequential(*[_all_nodes[k] for k in key])
+ else:
+ raise KeyError(f'Unknown type of key: {type(key)}')
+
+ def __repr__(self):
+ nodes = self.__all_nodes()
+ entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes))
+ return f'{self.__class__.__name__}(\n{entries}\n)'
-class System(Network):
- pass
-class NeuGroup(DynamicalSystem):
- """Base class to model neuronal groups.
+
+class Projection(DynamicalSystem):
+ def reset_state(self, *args, **kwargs):
+ pass
+
+
+class Dynamics(DynamicalSystem):
+ """Base class to model dynamics.
There are several essential attributes:
@@ -620,28 +526,21 @@ class NeuGroup(DynamicalSystem):
- ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
`num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
- Parameters
- ----------
- size : int, tuple of int, list of int
- The neuron group geometry.
- name : optional, str
- The name of the dynamic system.
- keep_size: bool
- Whether keep the geometry information.
-
- .. versionadded:: 2.1.13
- mode: Mode
- The computing mode.
-
- .. versionadded:: 2.2.0
+ Args:
+ size: The neuron group geometry.
+ name: The name of the dynamic system.
+    keep_size: Whether to keep the geometry information.
+ mode: The computing mode.
"""
def __init__(
self,
size: Shape,
keep_size: bool = False,
- name: str = None,
- mode: bm.Mode = None,
+ sharding: Optional[Any] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ method: str = 'exp_auto'
):
# size
if isinstance(size, (list, tuple)):
@@ -663,17 +562,23 @@ def __init__(
# number of neurons
self.num = tools.size2num(size)
+ # axis names for parallelization
+ self.sharding = sharding
+
+ # integration method
+ self.method = method
+
+ # inputs
+ self.cur_inputs: Dict = bm.node_dict()
+
# initialize
- super(NeuGroup, self).__init__(name=name, mode=mode)
+ super().__init__(name=name, mode=mode)
@property
def varshape(self):
"""The shape of variables in the neuron group."""
return self.size if self.keep_size else (self.num,)
- def __repr__(self):
- return f'{self.__class__.__name__}(name={self.name}, mode={self.mode}, size={self.size})'
-
def get_batch_shape(self, batch_size=None):
if batch_size is None:
return self.varshape
@@ -686,540 +591,60 @@ def update(self, *args, **kwargs):
raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
f'implement "update" function.')
- def clear_input(self):
- """Function to clear inputs in the neuron group.
- It will be useful when monitoring inputs of the object received."""
- pass
+ def init_param(self, param, shape=None, sharding=None):
+ """Initialize parameters.
- def __getitem__(self, item):
- return NeuGroupView(target=self, index=item)
+    If ``sharding`` is provided and ``param`` is an array, this function will
+ partition the parameter across the default device mesh.
-
-class SynConn(DynamicalSystem):
- """Base class to model two-end synaptic connections.
-
- Parameters
- ----------
- pre : NeuGroup
- Pre-synaptic neuron group.
- post : NeuGroup
- Post-synaptic neuron group.
- conn : optional, ndarray, ArrayType, dict, TwoEndConnector
- The connection method between pre- and post-synaptic groups.
- name : str, optional
- The name of the dynamic system.
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None,
- name: str = None,
- mode: bm.Mode = None,
- ):
- super(SynConn, self).__init__(name=name, mode=mode)
-
- # pre or post neuron group
- # ------------------------
- if not isinstance(pre, (NeuGroup, DynamicalSystem)):
- raise TypeError('"pre" must be an instance of NeuGroup.')
- if not isinstance(post, (NeuGroup, DynamicalSystem)):
- raise TypeError('"post" must be an instance of NeuGroup.')
- self.pre = pre
- self.post = post
-
- # connectivity
- # ------------
- if isinstance(conn, TwoEndConnector):
- self.conn = conn(pre.size, post.size)
- elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)):
- if (pre.num, post.num) != conn.shape:
- raise ValueError(f'"conn" is provided as a matrix, and it is expected '
- f'to be an array with shape of (pre.num, post.num) = '
- f'{(pre.num, post.num)}, however we got {conn.shape}')
- self.conn = MatConn(conn_mat=conn)
- elif isinstance(conn, dict):
- if not ('i' in conn and 'j' in conn):
- raise ValueError(f'"conn" is provided as a dict, and it is expected to '
- f'be a dictionary with "i" and "j" specification, '
- f'however we got {conn}')
- self.conn = IJConn(i=conn['i'], j=conn['j'])
- elif isinstance(conn, str):
- self.conn = conn
- elif conn is None:
- self.conn = None
- else:
- raise ValueError(f'Unknown "conn" type: {conn}')
-
- def __repr__(self):
- names = self.__class__.__name__
- return (f'{names}(name={self.name}, mode={self.mode}, \n'
- f'{" " * len(names)} pre={self.pre}, \n'
- f'{" " * len(names)} post={self.post})')
-
- def check_pre_attrs(self, *attrs):
- """Check whether pre group satisfies the requirement."""
- if not hasattr(self, 'pre'):
- raise ValueError('Please call __init__ function first.')
- for attr in attrs:
- if not isinstance(attr, str):
- raise TypeError(f'Must be string. But got {attr}.')
- if not hasattr(self.pre, attr):
- raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')
-
- def check_post_attrs(self, *attrs):
- """Check whether post group satisfies the requirement."""
- if not hasattr(self, 'post'):
- raise ValueError('Please call __init__ function first.')
- for attr in attrs:
- if not isinstance(attr, str):
- raise TypeError(f'Must be string. But got {attr}.')
- if not hasattr(self.post, attr):
- raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')
-
- def update(self, *args, **kwargs):
- """The function to specify the updating rule.
-
- Assume any dynamical system depends on the shared variables (`sha`),
- like time variable ``t``, the step precision ``dt``, and the time step `i`.
+ See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting.
"""
- raise NotImplementedError('Must implement "update" function by subclass self.')
-
-
-class _SynComponent(DynamicalSystem):
- """Base class for modeling synaptic components,
- including synaptic output, synaptic short-term plasticity,
- synaptic long-term plasticity, and others. """
-
- '''Master of this component.'''
- master: SynConn
-
- def __init__(self, *args, **kwargs):
- super(_SynComponent, self).__init__(*args, **kwargs)
+ shape = self.varshape if shape is None else shape
+ sharding = self.sharding if sharding is None else sharding
+ return parameter(param,
+ sizes=shape,
+ allow_none=False,
+ sharding=sharding)
- self._registered = False
+ def init_variable(self, var_data, batch_or_mode, shape=None, sharding=None):
+ """Initialize variables.
- @property
- def isregistered(self) -> bool:
- """State of the component, representing whether it has been registered."""
- return self._registered
-
- @isregistered.setter
- def isregistered(self, val: bool):
- if not isinstance(val, bool):
- raise ValueError('Must be an instance of bool.')
- self._registered = val
-
- def reset_state(self, batch_size=None):
- pass
+    If ``sharding`` is provided and ``var_data`` is an array, this function will
+ partition the variable across the default device mesh.
- def register_master(self, master: SynConn):
- if not isinstance(master, SynConn):
- raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}')
- if self.isregistered:
- raise ValueError(f'master has been registered, but we got another master going to be registered.')
- if hasattr(self, 'master') and self.master != master:
- raise ValueError(f'master has been registered, but we got another master going to be registered.')
- self.master = master
- self._registered = True
+ See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting.
+ """
+ shape = self.varshape if shape is None else shape
+ sharding = self.sharding if sharding is None else sharding
+ return variable_(var_data,
+ sizes=shape,
+ batch_or_mode=batch_or_mode,
+ axis_names=sharding,
+ batch_axis_name=bm.sharding.BATCH_AXIS)
def __repr__(self):
- return self.__class__.__name__
-
- def __call__(self, *args, **kwargs):
- return self.filter(*args, **kwargs)
-
- def clone(self) -> '_SynComponent':
- """The function useful to clone a new object when it has been used."""
- raise NotImplementedError
-
- def filter(self, g):
- raise NotImplementedError
-
-
-class SynOut(_SynComponent):
- """Base class for synaptic current output."""
-
- def __init__(
- self,
- name: str = None,
- target_var: Union[str, Variable] = None,
- ):
- super(SynOut, self).__init__(name=name)
- # check target variable
- if target_var is not None:
- if not isinstance(target_var, (str, Variable)):
- raise TypeError('"target_var" must be instance of string or Variable. '
- f'But we got {type(target_var)}')
- self.target_var: Optional[Variable] = target_var
-
- def register_master(self, master: SynConn):
- super(SynOut, self).register_master(master)
-
- # initialize target variable to output
- if isinstance(self.target_var, str):
- if not hasattr(self.master.post, self.target_var):
- raise KeyError(f'Post-synaptic group does not have target variable: {self.target_var}')
- self.target_var = getattr(self.master.post, self.target_var)
-
- def filter(self, g):
- if self.target_var is None:
- return g
- else:
- self.target_var += g
-
- def update(self, tdi):
- pass
-
-
-class SynSTP(_SynComponent):
- """Base class for synaptic short-term plasticity."""
-
- def update(self, tdi, pre_spike):
- pass
-
-
-class SynLTP(_SynComponent):
- """Base class for synaptic long-term plasticity."""
-
- def update(self, tdi, pre_spike):
- pass
-
-
-class NullSynOut(SynOut):
- def clone(self):
- return NullSynOut()
-
-
-class TwoEndConn(SynConn):
- """Base class to model synaptic connections.
-
- Parameters
- ----------
- pre : NeuGroup
- Pre-synaptic neuron group.
- post : NeuGroup
- Post-synaptic neuron group.
- conn : optional, ndarray, ArrayType, dict, TwoEndConnector
- The connection method between pre- and post-synaptic groups.
- output: Optional, SynOutput
- The output for the synaptic current.
-
- .. versionadded:: 2.1.13
- The output component for a two-end connection model.
-
- stp: Optional, SynSTP
- The short-term plasticity model for the synaptic variables.
-
- .. versionadded:: 2.1.13
- The short-term plasticity component for a two-end connection model.
-
- ltp: Optional, SynLTP
- The long-term plasticity model for the synaptic variables.
-
- .. versionadded:: 2.1.13
- The long-term plasticity component for a two-end connection model.
-
- name: Optional, str
- The name of the dynamic system.
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None,
- output: SynOut = NullSynOut(),
- stp: Optional[SynSTP] = None,
- ltp: Optional[SynLTP] = None,
- mode: bm.Mode = None,
- name: str = None,
- ):
- super(TwoEndConn, self).__init__(pre=pre,
- post=post,
- conn=conn,
- name=name,
- mode=mode)
-
- # synaptic output
- output = NullSynOut() if output is None else output
- if output.isregistered: output = output.clone()
- if not isinstance(output, SynOut):
- raise TypeError(f'output must be instance of {SynOut.__name__}, '
- f'but we got {type(output)}')
- output.register_master(master=self)
- self.output: SynOut = output
-
- # short-term synaptic plasticity
- if stp is not None:
- if stp.isregistered: stp = stp.clone()
- if not isinstance(stp, SynSTP):
- raise TypeError(f'Short-term plasticity must be instance of {SynSTP.__name__}, '
- f'but we got {type(stp)}')
- stp.register_master(master=self)
- self.stp: SynSTP = stp
-
- # long-term synaptic plasticity
- if ltp is not None:
- if ltp.isregistered: ltp = ltp.clone()
- if not isinstance(ltp, SynLTP):
- raise TypeError(f'Long-term plasticity must be instance of {SynLTP.__name__}, '
- f'but we got {type(ltp)}')
- ltp.register_master(master=self)
- self.ltp: SynLTP = ltp
-
- def _init_weights(
- self,
- weight: Union[float, ArrayType, Initializer, Callable],
- comp_method: str,
- sparse_data: str = 'csr'
- ) -> Tuple[Union[float, ArrayType], ArrayType]:
- if comp_method not in ['sparse', 'dense']:
- raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
- if sparse_data not in ['csr', 'ij', 'coo']:
- raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
- if self.conn is None:
- raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
-
- # connections and weights
- if isinstance(self.conn, One2One):
- weight = parameter(weight, (self.pre.num,), allow_none=False)
- conn_mask = None
-
- elif isinstance(self.conn, All2All):
- weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
- conn_mask = None
-
- else:
- if comp_method == 'sparse':
- if sparse_data == 'csr':
- conn_mask = self.conn.require('pre2post')
- elif sparse_data in ['ij', 'coo']:
- conn_mask = self.conn.require('post_ids', 'pre_ids')
- else:
- ValueError(f'Unknown sparse data type: {sparse_data}')
- weight = parameter(weight, conn_mask[0].shape, allow_none=False)
- elif comp_method == 'dense':
- weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
- conn_mask = self.conn.require('conn_mat')
- else:
- raise ValueError(f'Unknown connection type: {comp_method}')
-
- # training weights
- if isinstance(self.mode, bm.TrainingMode):
- weight = bm.TrainVar(weight)
- return weight, conn_mask
-
- def _syn2post_with_all2all(self, syn_value, syn_weight):
- if bm.ndim(syn_weight) == 0:
- if isinstance(self.mode, bm.BatchingMode):
- post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
- else:
- post_vs = bm.sum(syn_value)
- if not self.conn.include_self:
- post_vs = post_vs - syn_value
- post_vs = syn_weight * post_vs
- else:
- post_vs = syn_value @ syn_weight
- return post_vs
-
- def _syn2post_with_one2one(self, syn_value, syn_weight):
- return syn_value * syn_weight
-
- def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
- if bm.ndim(syn_weight) == 0:
- post_vs = (syn_weight * syn_value) @ conn_mat
- else:
- post_vs = syn_value @ (syn_weight * conn_mat)
- return post_vs
-
-
-class TwoEndConnNS(TwoEndConn):
- """Two-end connection without passing shared arguments."""
- _pass_shared_args = False
-
-
-class CondNeuGroup(NeuGroup, Container):
- r"""Base class to model conductance-based neuron group.
-
- The standard formulation for a conductance-based model is given as
-
- .. math::
-
- C_m {dV \over dt} = \sum_jg_j(E - V) + I_{ext}
-
- where :math:`g_j=\bar{g}_{j} M^x N^y` is the channel conductance, :math:`E` is the
- reversal potential, :math:`M` is the activation variable, and :math:`N` is the
- inactivation variable.
-
- :math:`M` and :math:`N` have the dynamics of
-
- .. math::
-
- {dx \over dt} = \phi_x {x_\infty (V) - x \over \tau_x(V)}
-
- where :math:`x \in [M, N]`, :math:`\phi_x` is a temperature-dependent factor,
- :math:`x_\infty` is the steady state, and :math:`\tau_x` is the time constant.
- Equivalently, the above equation can be written as:
-
- .. math::
-
- \frac{d x}{d t}=\phi_{x}\left(\alpha_{x}(1-x)-\beta_{x} x\right)
-
- where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants.
-
- .. versionadded:: 2.1.9
- Model the conductance-based neuron model.
-
- Parameters
- ----------
- size : int, sequence of int
- The network size of this neuron group.
- method: str
- The numerical integration method.
- name : optional, str
- The neuron group name.
-
- See Also
- --------
- Channel
-
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- C: Union[float, ArrayType, Initializer, Callable] = 1.,
- A: Union[float, ArrayType, Initializer, Callable] = 1e-3,
- V_th: Union[float, ArrayType, Initializer, Callable] = 0.,
- V_initializer: Union[Initializer, Callable, ArrayType] = Uniform(-70, -60.),
- noise: Union[float, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
- **channels
- ):
- NeuGroup.__init__(self, size, keep_size=keep_size, mode=mode)
- Container.__init__(self, **channels, name=name, mode=mode)
-
- # parameters for neurons
- self.C = C
- self.A = A
- self.V_th = V_th
- self.noise = init_noise(noise, self.varshape, num_vars=3)
- self._V_initializer = V_initializer
-
- # variables
- self.V = variable(V_initializer, self.mode, self.varshape)
- self.input = variable(bm.zeros, self.mode, self.varshape)
- self.spike = variable(lambda s: bm.zeros(s, dtype=bool), self.mode, self.varshape)
-
- # function
- if self.noise is None:
- self.integral = odeint(f=self.derivative, method=method)
- else:
- self.integral = sdeint(f=self.derivative, g=self.noise, method=method)
-
- def derivative(self, V, t):
- Iext = self.input.value * (1e-3 / self.A)
- channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
- for ch in channels.values():
- Iext = Iext + ch.current(V)
- return Iext / self.C
-
- def reset_state(self, batch_size=None):
- self.V.value = variable(self._V_initializer, batch_size, self.varshape)
- self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
- self.input.value = variable(bm.zeros, batch_size, self.varshape)
- for channel in self.nodes(level=1, include_self=False).subset(Channel).unique().values():
- channel.reset_state(self.V.value, batch_size=batch_size)
-
- def update(self, tdi, *args, **kwargs):
- V = self.integral(self.V.value, tdi['t'], tdi['dt'])
-
- channels = self.nodes(level=1, include_self=False).subset(Channel).unique()
- # check whether the children channels have the correct parents.
- check_master(type(self), **channels)
-
- # update variables
- for node in channels.values():
- node.update(tdi, self.V.value)
- self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
- self.V.value = V
-
- def register_implicit_nodes(self, *channels, **named_channels):
- check_master(type(self), *channels, **named_channels)
- super(CondNeuGroup, self).register_implicit_nodes(*channels, **named_channels)
-
- def clear_input(self):
- """Useful for monitoring inputs. """
- self.input.value = bm.zeros_like(self.input.value)
-
-
-class Channel(DynamicalSystem):
- """Abstract channel class."""
-
- master_type = CondNeuGroup
-
- def __init__(
- self,
- size: Union[int, Sequence[int]],
- name: str = None,
- keep_size: bool = False,
- mode: bm.Mode = None,
- ):
- super(Channel, self).__init__(name=name, mode=mode)
- # the geometry size
- self.size = tools.to_size(size)
- # the number of elements
- self.num = tools.size2num(self.size)
- # variable shape
- self.keep_size = keep_size
-
- @property
- def varshape(self):
- return self.size if self.keep_size else self.num
+ return f'{self.__class__.__name__}(name={self.name}, mode={self.mode}, size={self.size})'
- def update(self, tdi, V):
- raise NotImplementedError('Must be implemented by the subclass.')
+ def __getitem__(self, item):
+ return NeuDynView(target=self, index=item)
- def current(self, V):
- raise NotImplementedError('Must be implemented by the subclass.')
- def reset_state(self, V, batch_size=None):
- raise NotImplementedError('Must be implemented by the subclass.')
+class NeuDyn(Dynamics, AutoDelaySupp):
+ """Neuronal Dynamics."""
+ pass
-def _check(master, child):
- if not hasattr(child, 'master_type'):
- raise ValueError('Child class should define "master_type" to specify the type of the master. '
- f'But we did not found it in {child}')
- if not issubclass(master, child.master_type):
- raise TypeError(f'Type does not match. {child} requires a master with type '
- f'of {child.master_type}, but the master now is {master}.')
+class SynDyn(Dynamics, AutoDelaySupp, ParamDesc):
+ """Synaptic Dynamics."""
+ pass
-def check_master(master, *channels, **named_channels):
- for channel in channels:
- if isinstance(channel, Channel):
- _check(master, channel)
- elif isinstance(channel, (list, tuple)):
- for ch in channel:
- _check(master, ch)
- elif isinstance(channel, dict):
- for ch in channel.values():
- _check(master, ch)
- else:
- raise ValueError(f'Do not support {type(channel)}.')
- for channel in named_channels.values():
- if not isinstance(channel, Channel):
- raise ValueError(f'Do not support {type(channel)}. ')
- _check(master, channel)
+class IonChaDyn(Dynamics):
+ """Ion Channel Dynamics."""
+ pass
-class DSView(DynamicalSystem):
+class DynView(Dynamics):
"""DSView, an object used to get a view of a dynamical system instance.
It can get a subset view of variables in a dynamical system instance.
@@ -1227,14 +652,14 @@ class DSView(DynamicalSystem):
>>> import brainpy as bp
>>> hh = bp.neurons.HH(10)
- >>> DSView(hh, slice(5, 10, None))
+ >>> DynView(hh, slice(5, 10, None))
>>> # or, simply
>>> hh[5:]
"""
def __init__(
self,
- target: DynamicalSystem,
+ target: Dynamics,
index: Union[slice, Sequence, ArrayType],
varshape: Tuple[int, ...] = None,
name: str = None,
@@ -1256,7 +681,7 @@ def __init__(
# get all variables for slicing
if not hasattr(self.target, SLICE_VARS):
if varshape is None:
- if isinstance(target, NeuGroup):
+ if isinstance(target, NeuDyn):
varshape = target.varshape
else:
raise UnsupportedError('Should provide varshape when the target does '
@@ -1282,15 +707,15 @@ def __init__(
for _ in range(v.batch_axis - len(self.index) + 1)])))
else:
index = self.index
- self.slice_vars[k] = VariableView(v, index)
+ self.slice_vars[k] = bm.VariableView(v, index)
# sub-nodes
nodes = target.nodes(method='relative', level=1, include_self=False).subset(DynamicalSystem)
for k, node in nodes.items():
- if isinstance(node, NeuGroup):
- node = NeuGroupView(node, self.index)
+ if isinstance(node, NeuDyn):
+ node = NeuDynView(node, self.index)
else:
- node = DSView(node, self.index, varshape)
+ node = DynView(node, self.index, varshape)
setattr(self, k, node)
def __repr__(self):
@@ -1308,12 +733,12 @@ def __getattribute__(self, item):
def __setattr__(self, key, value):
if hasattr(self, 'slice_vars'):
- slice_vars = super(DSView, self).__getattribute__('slice_vars')
+ slice_vars = super(DynView, self).__getattribute__('slice_vars')
if key in slice_vars:
v = slice_vars[key]
v.value = value
return
- super(DSView, self).__setattr__(key, value)
+ super(DynView, self).__setattr__(key, value)
def update(self, *args, **kwargs):
raise NoImplementationError(f'DSView {self} cannot be updated. Please update its parent {self.target}')
@@ -1350,17 +775,17 @@ def _slice_to_num(slice_: slice, length: int):
return num
-class NeuGroupView(DSView, NeuGroup):
+class NeuDynView(DynView, NeuDyn):
"""A view for a neuron group instance."""
def __init__(
self,
- target: NeuGroup,
+ target: NeuDyn,
index: Union[slice, Sequence, ArrayType],
name: str = None,
mode: bm.Mode = None
):
- DSView.__init__(self, target, index)
+ DynView.__init__(self, target, index)
# check slicing
var_shapes = target.varshape
@@ -1385,129 +810,4 @@ def __init__(
size += list(var_shapes[len(self.index):])
# initialization
- NeuGroup.__init__(self, tuple(size), name=name, mode=mode)
-
-
-class DynamicalSystemNS(DynamicalSystem):
- """Dynamical system without the need to pass shared parameters into ``update()`` function."""
-
- _pass_shared_args = False
-
-
-class Sequential(DynamicalSystemNS):
- """A sequential `input-output` module.
-
- Modules will be added to it in the order they are passed in the
- constructor. Alternatively, an ``dict`` of modules can be
- passed in. The ``update()`` method of ``Sequential`` accepts any
- input and forwards it to the first module it contains. It then
- "chains" outputs to inputs sequentially for each subsequent module,
- finally returning the output of the last module.
-
- The value a ``Sequential`` provides over manually calling a sequence
- of modules is that it allows treating the whole container as a
- single module, such that performing a transformation on the
- ``Sequential`` applies to each of the modules it stores (which are
- each a registered submodule of the ``Sequential``).
-
- What's the difference between a ``Sequential`` and a
- :py:class:`Container`? A ``Container`` is exactly what it
- sounds like--a container to store :py:class:`DynamicalSystem` s!
- On the other hand, the layers in a ``Sequential`` are connected
- in a cascading way.
-
- Examples
- --------
-
- >>> import brainpy as bp
- >>> import brainpy.math as bm
- >>>
- >>> # composing ANN models
- >>> l = bp.Sequential(bp.layers.Dense(100, 10),
- >>> bm.relu,
- >>> bp.layers.Dense(10, 2))
- >>> l({}, bm.random.random((256, 100)))
- >>>
- >>> # Using Sequential with Dict. This is functionally the
- >>> # same as the above code
- >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10),
- >>> l2=bm.relu,
- >>> l3=bp.layers.Dense(10, 2))
- >>> l({}, bm.random.random((256, 100)))
-
- Parameters
- ----------
- name: str
- The object name.
- mode: Mode
- The object computing context/mode. Default is ``None``.
- """
-
- def __init__(
- self,
- *modules_as_tuple,
- name: str = None,
- mode: bm.Mode = None,
- **modules_as_dict
- ):
- super().__init__(name=name, mode=mode)
- self._dyn_modules = bm.NodeDict()
- self._static_modules = dict()
- i = 0
- for m in modules_as_tuple + tuple(modules_as_dict.values()):
- key = self.__format_key(i)
- if isinstance(m, bm.BrainPyObject):
- self._dyn_modules[key] = m
- else:
- self._static_modules[key] = m
- i += 1
- self._num = i
-
- def __format_key(self, i):
- return f'l-{i}'
-
- def __all_nodes(self):
- nodes = []
- for i in range(self._num):
- key = self.__format_key(i)
- if key not in self._dyn_modules:
- nodes.append(self._static_modules[key])
- else:
- nodes.append(self._dyn_modules[key])
- return nodes
-
- def __getitem__(self, key: Union[int, slice, str]):
- if isinstance(key, str):
- if key in self._dyn_modules:
- return self._dyn_modules[key]
- elif key in self._static_modules:
- return self._static_modules[key]
- else:
- raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
- elif isinstance(key, slice):
- return Sequential(*(self.__all_nodes()[key]))
- elif isinstance(key, int):
- return self.__all_nodes()[key]
- elif isinstance(key, (tuple, list)):
- _all_nodes = self.__all_nodes()
- return Sequential(*[_all_nodes[k] for k in key])
- else:
- raise KeyError(f'Unknown type of key: {type(key)}')
-
- def __repr__(self):
- nodes = self.__all_nodes()
- entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes))
- return f'{self.__class__.__name__}(\n{entries}\n)'
-
- def update(self, x):
- """Update function of a sequential model.
- """
- for m in self.__all_nodes():
- x = m(x)
- return x
-
-
-class NeuGroupNS(NeuGroup):
- """Base class for neuron group without shared arguments passed."""
- _pass_shared_args = False
-
+ NeuDyn.__init__(self, tuple(size), name=name, mode=mode)
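The view classes above work by intercepting attribute reads and writes and routing them through an index into the parent's variables. A minimal, self-contained sketch of that pattern (a hypothetical ``SlicedView`` backed by NumPy storage, not the BrainPy implementation):

    import numpy as np

    class SlicedView:
        # hypothetical stand-in for DynView: attribute access goes through an index
        def __init__(self, target, index):
            # bypass __setattr__ while the bookkeeping attributes are created
            object.__setattr__(self, 'target', target)
            object.__setattr__(self, 'index', index)
            object.__setattr__(self, 'slice_vars', {'V': (target['V'], index)})

        def __getattr__(self, key):
            slice_vars = object.__getattribute__(self, 'slice_vars')
            if key in slice_vars:
                arr, idx = slice_vars[key]
                return arr[idx]
            raise AttributeError(key)

        def __setattr__(self, key, value):
            slice_vars = object.__getattribute__(self, 'slice_vars')
            if key in slice_vars:
                arr, idx = slice_vars[key]
                arr[idx] = value          # the write lands in the parent's storage
                return
            object.__setattr__(self, key, value)

    group = {'V': np.zeros(10)}
    view = SlicedView(group, slice(2, 5))
    view.V = 1.0                          # only elements 2..4 of the parent change
    print(group['V'])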
diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py
index e4b57ff46..9d1b1adcf 100644
--- a/brainpy/_src/integrators/ode/exponential.py
+++ b/brainpy/_src/integrators/ode/exponential.py
@@ -138,7 +138,7 @@ class ExponentialEuler(ODEIntegrator):
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
- >>> class HH(bp.NeuGroup):
+ >>> class HH(bp.NeuDyn):
>>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
>>> gL=0.1, V_th=20., phi=5.0, name=None):
>>> super(HH, self).__init__(size=size, name=name)
@@ -211,7 +211,7 @@ class ExponentialEuler(ODEIntegrator):
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
- >>> class HH(bp.NeuGroup):
+ >>> class HH(bp.NeuDyn):
>>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
>>> gL=0.1, V_th=20., phi=5.0, name=None):
>>> super(HH, self).__init__(size=size, name=name)
diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
index d950c509c..46654c4a0 100644
--- a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
@@ -46,7 +46,7 @@ def dev(x, t):
class TestExpEulerAuto(unittest.TestCase):
def test_hh_model(self):
- class HH(bp.NeuGroup):
+ class HH(bp.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
gL=0.1, V_th=20., phi=5.0, name=None, method='exponential_euler'):
super(HH, self).__init__(size=size, name=name)
diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py
index dcb2688e0..d8da11c9e 100644
--- a/brainpy/_src/math/compat_numpy.py
+++ b/brainpy/_src/math/compat_numpy.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
+
import jax
import jax.numpy as jnp
import numpy as np
diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py
index 70031a17a..419f2d146 100644
--- a/brainpy/_src/math/compat_pytorch.py
+++ b/brainpy/_src/math/compat_pytorch.py
@@ -34,7 +34,6 @@
]
-
Tensor = Array
cat = concatenate
diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py
index b5b0c7f08..1e28fb232 100644
--- a/brainpy/_src/math/delayvars.py
+++ b/brainpy/_src/math/delayvars.py
@@ -341,6 +341,10 @@ def __init__(
self.num_delay_step: int = 0
self.idx: Variable = None
+ self.delay_target = None
+ if isinstance(delay_target, Variable):
+ self.delay_target = delay_target
+
# initialization
self.reset(delay_target, delay_len, initial_delay_data, batch_axis)
@@ -448,7 +452,7 @@ def retrieve(self, delay_len, *indices):
# the delay data
return self.data[indices]
- def update(self, value: Union[numbers.Number, Array, jax.Array]):
+ def update(self, value: Union[numbers.Number, Array, jax.Array] = None):
"""Update delay variable with the new data.
Parameters
@@ -456,6 +460,12 @@ def update(self, value: Union[numbers.Number, Array, jax.Array]):
value: Any
The value of the latest data, used to update this delay variable.
"""
+ if value is None:
+ if self.delay_target is None:
+ raise ValueError('Must provide value.')
+ else:
+ value = self.delay_target.value
+
if self.update_method == ROTATE_UPDATE:
self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.num_delay_step))
self.data[self.idx[0]] = value
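With this change, ``update()`` may be called without an argument when the delay was built from a ``Variable``: the stored ``delay_target`` is read automatically. A rough usage sketch, assuming the class shown here is ``bm.LengthDelay`` (the name used for it elsewhere in this patch) and a plain non-batched target:

    import brainpy.math as bm

    v = bm.Variable(bm.zeros(5))        # the variable to be delayed
    delay = bm.LengthDelay(v, 10)       # the Variable is remembered as the delay target

    v.value += 1.0
    delay.update()                      # no argument: reads v.value through the bound target
    delay.update(bm.ones(5) * 2.)       # passing an explicit value still works as before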
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index 95f1d6ecb..fe997846d 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -87,6 +87,14 @@ def _check_tracer(self):
'Please declare it as a Variable.') from jax.core.escaped_tracer_error(self_value, None)
return self_value
+ @property
+ def sharding(self):
+ return self._value.sharding
+
+ @property
+ def addressable_shards(self):
+ return self._value.addressable_shards
+
@property
def value(self):
return self._value
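The two new properties simply forward to the wrapped ``jax.Array``, so sharding information can be read from a BrainPy array without touching ``.value`` directly. A quick sanity check of the expected single-device behaviour:

    import brainpy.math as bm

    x = bm.ones((4, 4))
    print(x.sharding)                   # forwarded from x.value.sharding
    print(len(x.addressable_shards))    # 1 shard on a single device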
diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py
index b3762865d..9435cf56f 100644
--- a/brainpy/_src/math/object_transform/tests/test_base.py
+++ b/brainpy/_src/math/object_transform/tests/test_base.py
@@ -81,7 +81,7 @@ def __init__(self):
class TestNodeList(bp.testing.UnitTestCase):
def test_NodeList_1(self):
- class Object(bp.DynamicalSystemNS):
+ class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
@@ -121,7 +121,7 @@ def update(self, x):
class TestNodeDict(bp.testing.UnitTestCase):
def test_NodeDict_1(self):
- class Object(bp.DynamicalSystemNS):
+ class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
@@ -167,7 +167,7 @@ def update(self, x):
class TestVarList(bp.testing.UnitTestCase):
def test_ListVar_1(self):
- class Object(bp.DynamicalSystemNS):
+ class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.vs = bm.VarList([bm.Variable(1.),
@@ -196,7 +196,7 @@ def f2():
class TestVarDict(bp.testing.UnitTestCase):
def test_DictVar_1(self):
- class Object(bp.DynamicalSystemNS):
+ class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.vs = bm.VarDict({'a': bm.Variable(1.),
diff --git a/brainpy/_src/math/object_transform/tests/test_circular_reference.py b/brainpy/_src/math/object_transform/tests/test_circular_reference.py
index 8e66f7afd..2dc076ff4 100644
--- a/brainpy/_src/math/object_transform/tests/test_circular_reference.py
+++ b/brainpy/_src/math/object_transform/tests/test_circular_reference.py
@@ -5,7 +5,7 @@
import brainpy as bp
-class HH(bp.NeuGroup):
+class HH(bp.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0,
gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs):
super(HH, self).__init__(size=size, **kwargs)
diff --git a/brainpy/_src/math/object_transform/tests/test_collector.py b/brainpy/_src/math/object_transform/tests/test_collector.py
index 142f779b3..f5b7fb0d0 100644
--- a/brainpy/_src/math/object_transform/tests/test_collector.py
+++ b/brainpy/_src/math/object_transform/tests/test_collector.py
@@ -40,7 +40,7 @@ def update(self, tdi):
self.post.inputs -= jnp.sum(self.s, axis=0) * (self.post.V - self.E)
-class HH_without_Variable(bp.NeuGroup):
+class HH_without_Variable(bp.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0,
gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs):
super(HH_without_Variable, self).__init__(size=size, **kwargs)
@@ -117,7 +117,7 @@ def test_neu_vars_1():
assert len(vars) == 0
-class HH_with_Variable(bp.NeuGroup):
+class HH_with_Variable(bp.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0,
gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs):
super(HH_with_Variable, self).__init__(size=size, **kwargs)
diff --git a/brainpy/_src/math/object_transform/tests/test_namechecking.py b/brainpy/_src/math/object_transform/tests/test_namechecking.py
index a8404e03d..c008cd4a9 100644
--- a/brainpy/_src/math/object_transform/tests/test_namechecking.py
+++ b/brainpy/_src/math/object_transform/tests/test_namechecking.py
@@ -4,7 +4,7 @@
import brainpy as bp
-class LIF(bp.NeuGroup):
+class LIF(bp.NeuDyn):
pass
diff --git a/brainpy/_src/math/object_transform/tests/test_tools.py b/brainpy/_src/math/object_transform/tests/test_tools.py
index e5a897f79..69781d624 100644
--- a/brainpy/_src/math/object_transform/tests/test_tools.py
+++ b/brainpy/_src/math/object_transform/tests/test_tools.py
@@ -92,7 +92,7 @@ def f2():
def test_cache3(self):
call_num = [0]
- class Model(bp.DynamicalSystemNS):
+ class Model(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.a = bm.Variable(bm.ones(1))
diff --git a/brainpy/_src/math/sharding.py b/brainpy/_src/math/sharding.py
index cb2faeed3..7ab697742 100644
--- a/brainpy/_src/math/sharding.py
+++ b/brainpy/_src/math/sharding.py
@@ -131,8 +131,11 @@ def partition_by_sharding(
return x
else:
assert isinstance(sharding, Sharding)
- f = partial(_device_put, device=sharding)
- return jax.tree_util.tree_map(f, x, is_leaf=lambda a: isinstance(a, Array))
+ if isinstance(x, (Array, jax.Array)):
+ return _device_put(x, device=sharding)
+ return jax.tree_util.tree_map(partial(_device_put, device=sharding),
+ x,
+ is_leaf=lambda a: isinstance(a, Array))
def partition(
@@ -142,7 +145,11 @@ def partition(
if sharding is None:
return x
elif isinstance(sharding, (jax.Device, Sharding)):
- return jax.tree_util.tree_map(partial(_device_put, device=sharding), x, is_leaf=lambda a: isinstance(a, Array))
+ if isinstance(x, (Array, jax.Array)):
+ return _device_put(x, device=sharding)
+ return jax.tree_util.tree_map(partial(_device_put, device=sharding),
+ x,
+ is_leaf=lambda a: isinstance(a, Array))
elif isinstance(sharding, (tuple, list)) and any([isinstance(s, str) for s in sharding]):
return partition_by_axname(x, sharding)
else:
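Both hunks add the same fast path: a bare ``Array``/``jax.Array`` is placed directly, and only genuine pytrees go through ``tree_map``. A standalone sketch of that dispatch (using raw JAX, not the BrainPy helpers):

    import jax
    import jax.numpy as jnp
    from functools import partial

    def place(x, device):
        # single arrays take the direct path and skip tree_map entirely
        if isinstance(x, jax.Array):
            return jax.device_put(x, device)
        # genuine pytrees are handled leaf by leaf
        return jax.tree_util.tree_map(partial(jax.device_put, device=device), x)

    cpu = jax.devices('cpu')[0]
    print(place(jnp.ones(3), cpu).sharding)              # direct path
    print(place({'w': jnp.ones(3)}, cpu)['w'].sharding)  # tree_map path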
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index b8ced2648..0718b06e4 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -1,71 +1,100 @@
-from typing import Optional, Sequence, Union, Tuple, Callable
+import numbers
from dataclasses import dataclass
-from brainpy import tools, math as bm
+from typing import Union, Dict, Callable, Sequence, Optional, TypeVar
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from brainpy import math as bm, tools
+from brainpy._src.initialize import parameter
+from brainpy._src.typing_copy import _SpecialForm, _UnionGenericAlias, _type_check, _remove_dups_flatten
+from brainpy.types import ArrayType
+
+DynamicalSystem = None
__all__ = [
'MixIn',
'ParamDesc',
'AlignPost',
- 'ProjAutoDelay',
+ 'AutoDelaySupp',
+ 'NoSH',
+ 'Container',
+ 'TreeNode',
+ 'BindCondData',
+ 'JointType',
]
+global_delay_data = dict()
+
class MixIn(object):
+ """Base MixIn object."""
pass
-class DelayedInit(object):
- """Delayed initialization.
- """
+class ParamDesc(MixIn):
+ """:py:class:`~.MixIn` indicates the function for describing initialization parameters.
- def __init__(
- self,
- cls: type,
- identifier,
- *args,
- **kwargs
- ):
- self.cls = cls
- self.args = args
- self.kwargs = kwargs
- self._identifier = identifier
+ This mixin gives the subclass a classmethod ``desc``, which
+ produces an instance of :py:class:`~.ParamDescInit`.
- def __call__(self, *args, **kwargs):
- return self.cls(*self.args, *args, **self.kwargs, **kwargs)
+ Note that this MixIn can be applied to any Python object.
+ """
- def init(self, *args, **kwargs):
- return self.__call__(*args, **kwargs)
+ not_desc_params: Optional[Sequence[str]] = None
@classmethod
- def __class_getitem__(cls, item):
- return cls
+ def desc(cls, *args, **kwargs) -> 'ParamDescInit':
+ return ParamDescInit(cls, *args, **kwargs)
-class ParamDesc(MixIn):
- """Parameter description MixIn.
-
- This mixin enables the subclass has a classmethod ``desc``, which
- produces an instance of :py:class:`~.DelayedInit`.
+class ParamDescInit(object):
+ """Delayed initialization for parameter describers.
"""
- not_desc_params: Optional[Sequence[str]] = None
+ def __init__(self, cls: type, *args, **kwargs):
+ self.cls = cls
- @classmethod
- def desc(cls, *args, **kwargs) -> DelayedInit:
- # cls_args = list(inspect.signature(cls.__init__).parameters.values())[1:]
- # names = [arg.name for arg in cls_args]
- # defaults = [arg.default for arg in cls_args]
- if cls.not_desc_params is not None:
- repr_kwargs = {k: v for k, v in kwargs.items() if k not in cls.not_desc_params}
- else:
+ # arguments
+ self.args = args
+ self.kwargs = kwargs
+
+ # identifier
+ if isinstance(cls, _JointGenericAlias):
+ name = str(cls)
repr_kwargs = {k: v for k, v in kwargs.items()}
+ else:
+ assert isinstance(cls, type)
+ if issubclass(cls, ParamDesc) and (cls.not_desc_params is not None):
+ repr_kwargs = {k: v for k, v in kwargs.items() if k not in cls.not_desc_params}
+ else:
+ repr_kwargs = {k: v for k, v in kwargs.items()}
+ name = cls.__name__
for k in tuple(repr_kwargs.keys()):
if isinstance(repr_kwargs[k], bm.Variable):
repr_kwargs[k] = id(repr_kwargs[k])
repr_args = tools.repr_dict(repr_kwargs)
if len(args):
repr_args = f"{', '.join([repr(arg) for arg in args])}, {repr_args}"
- return DelayedInit(cls, f'{cls.__name__}({repr_args})', *args, **kwargs)
+ self._identifier = f'{name}({repr_args})'
+
+ def __call__(self, *args, **kwargs):
+ return self.cls(*self.args, *args, **self.kwargs, **kwargs)
+
+ def init(self, *args, **kwargs):
+ return self.__call__(*args, **kwargs)
+
+ def __instancecheck__(self, instance):
+ if not isinstance(instance, ParamDescInit):
+ return False
+ if not issubclass(instance.cls, self.cls):
+ return False
+ return True
+
+ @classmethod
+ def __class_getitem__(cls, item: type):
+ return ParamDescInit(item)
class AlignPost(MixIn):
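``ParamDesc.desc`` now returns a ``ParamDescInit`` that records the class and its arguments without constructing anything, and the identifier logic also covers ``JointType`` classes. A hedged usage sketch; ``MySyn`` below is a made-up class, only ``desc``/``ParamDescInit`` come from this patch:

    import brainpy as bp

    class MySyn(bp.mixin.ParamDesc):      # hypothetical user class
        def __init__(self, tau=5.):
            self.tau = tau

    desc = MySyn.desc(tau=10.)            # a ParamDescInit: cls + kwargs recorded, nothing built yet
    syn = desc()                          # the real construction happens here
    print(desc._identifier)               # something like "MySyn(tau=10.0)"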
@@ -82,14 +111,400 @@ def add_current(self, *args, **kwargs):
@dataclass
class ReturnInfo:
size: Sequence[int]
- axis_names: Sequence[str]
+ axis_names: Optional[Sequence[str]]
batch_or_mode: Optional[Union[int, bm.Mode]]
init: Callable
-class ProjAutoDelay(MixIn):
- """Support for automatic delay in synaptic projection :py:class:`~.SynProj`."""
+class AutoDelaySupp(MixIn):
+ """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""
def return_info(self) -> Union[bm.Variable, ReturnInfo]:
- raise NotImplementedError
+ raise NotImplementedError('Must implement the "return_info()" function.')
+
+
+class NoSH(MixIn):
+ """``MixIn`` to indicate that no shared parameters should be passed into the ``update()`` function."""
+
+ def __init__(self, *args, **kwargs):
+ self._pass_shared_args = False
+
+
+class Container(MixIn):
+ """Container :py:class:`~.MixIn` which wrap a group of objects.
+ """
+ children: bm.node_dict
+
+ def __getitem__(self, item):
+ """Overwrite the slice access (`self['']`). """
+ if item in self.children:
+ return self.children[item]
+ else:
+ raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}')
+
+ def __getattr__(self, item):
+ """Overwrite the dot access (`self.`). """
+ if item == 'children':
+ return super().__getattribute__('children')
+ else:
+ children = super().__getattribute__('children')
+ if item in children:
+ return children[item]
+ else:
+ return super().__getattribute__(item)
+
+ def __repr__(self):
+ cls_name = self.__class__.__name__
+ indent = ' ' * len(cls_name)
+ child_str = [tools.repr_context(repr(val), indent) for val in self.children.values()]
+ string = ", \n".join(child_str)
+ return f'{cls_name}({string})'
+
+ def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict):
+ res = dict()
+
+ # add tuple-typed components
+ for module in children_as_tuple:
+ if isinstance(module, child_type):
+ res[module.name] = module
+ elif isinstance(module, (list, tuple)):
+ for m in module:
+ if not isinstance(m, child_type):
+ raise ValueError(f'Should be instance of {child_type.__name__}. '
+ f'But we got {type(m)}')
+ res[m.name] = m
+ elif isinstance(module, dict):
+ for k, v in module.items():
+ if not isinstance(v, child_type):
+ raise ValueError(f'Should be instance of {child_type.__name__}. '
+ f'But we got {type(v)}')
+ res[k] = v
+ else:
+ raise ValueError(f'Cannot parse sub-systems. They should be {child_type.__name__} '
+ f'or a list/tuple/dict of {child_type.__name__}.')
+ # add dict-typed components
+ for k, v in children_as_dict.items():
+ if not isinstance(v, child_type):
+ raise ValueError(f'Should be instance of {child_type.__name__}. '
+ f'But we got {type(v)}')
+ res[k] = v
+ return res
+
+
+class TreeNode(MixIn):
+ """Tree node. """
+
+ master_type: type
+
+ @staticmethod
+ def check_hierarchies(root, *leaves, **named_leaves):
+ global DynamicalSystem
+ if DynamicalSystem is None:
+ from brainpy._src.dynsys import DynamicalSystem
+
+ for leaf in leaves:
+ if isinstance(leaf, DynamicalSystem):
+ TreeNode.check_hierarchy(root, leaf)
+ elif isinstance(leaf, (list, tuple)):
+ TreeNode.check_hierarchies(root, *leaf)
+ elif isinstance(leaf, dict):
+ TreeNode.check_hierarchies(root, **leaf)
+ else:
+ raise ValueError(f'Do not support {type(leaf)}.')
+ for leaf in named_leaves.values():
+ if not isinstance(leaf, DynamicalSystem):
+ raise ValueError(f'Do not support {type(leaf)}. Must be instance of {DynamicalSystem.__name__}')
+ TreeNode.check_hierarchy(root, leaf)
+
+ @staticmethod
+ def check_hierarchy(root, leaf):
+ if hasattr(leaf, 'master_type'):
+ master_type = leaf.master_type
+ else:
+ raise ValueError('Child class should define "master_type" to '
+ 'specify the type of the root node. '
+ f'But we did not find it in {leaf}')
+ if not issubclass(root, master_type):
+ raise TypeError(f'Type does not match. {leaf} requires a master with type '
+ f'of {leaf.master_type}, but the master now is {root}.')
+
+
+class DelayRegister(MixIn):
+ local_delay_vars: bm.node_dict
+
+ def register_delay(
+ self,
+ identifier: str,
+ delay_step: Optional[Union[int, ArrayType, Callable]],
+ delay_target: bm.Variable,
+ initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None,
+ ):
+ """Register delay variable.
+
+ Parameters
+ ----------
+ identifier: str
+ The delay variable name.
+ delay_step: Optional, int, ArrayType, callable, Initializer
+ The number of the steps of the delay.
+ delay_target: Variable
+ The target variable for delay.
+ initial_delay_data: float, int, ArrayType, callable, Initializer
+ The initializer for the delay data.
+
+ Returns
+ -------
+ delay_step: int, ArrayType
+ The number of the delay steps.
+ """
+ # delay steps
+ if delay_step is None:
+ delay_type = 'none'
+ elif isinstance(delay_step, (int, np.integer, jnp.integer)):
+ delay_type = 'homo'
+ elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ if delay_step.size == 1 and delay_step.ndim == 0:
+ delay_type = 'homo'
+ else:
+ delay_type = 'heter'
+ delay_step = bm.asarray(delay_step)
+ elif callable(delay_step):
+ delay_step = parameter(delay_step, delay_target.shape, allow_none=False)
+ delay_type = 'heter'
+ else:
+ raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}. Only integers, '
+ f'arrays of integers, callable functions, and brainpy.init.Initializer are supported.')
+ if delay_type == 'heter':
+ if delay_step.dtype not in [bm.int32, bm.int64]:
+ raise ValueError('Only delay steps of int32 or int64 are supported. If you '
+ 'provide a delay time length, please divide it by "dt" '
+ 'and provide the number of delay steps instead.')
+ if delay_target.shape[0] != delay_step.shape[0]:
+ raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}')
+ if delay_type != 'none':
+ max_delay_step = int(bm.max(delay_step))
+
+ # delay target
+ if delay_type != 'none':
+ if not isinstance(delay_target, bm.Variable):
+ raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}')
+
+ # delay variable
+ # TODO
+ if delay_type != 'none':
+ if identifier not in global_delay_data:
+ delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
+ global_delay_data[identifier] = (delay, delay_target)
+ self.local_delay_vars[identifier] = delay
+ else:
+ delay = global_delay_data[identifier][0]
+ if delay is None:
+ delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
+ global_delay_data[identifier] = (delay, delay_target)
+ self.local_delay_vars[identifier] = delay
+ elif delay.num_delay_step - 1 < max_delay_step:
+ global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
+ else:
+ if identifier not in global_delay_data:
+ global_delay_data[identifier] = (None, delay_target)
+ return delay_step
+
+ def get_delay_data(
+ self,
+ identifier: str,
+ delay_step: Optional[Union[int, bm.Array, jax.Array]],
+ *indices: Union[int, slice, bm.Array, jax.Array],
+ ):
+ """Get delay data according to the provided delay steps.
+
+ Parameters
+ ----------
+ identifier: str
+ The delay variable name.
+ delay_step: Optional, int, ArrayType
+ The delay length.
+ indices: optional, int, slice, ArrayType
+ The indices of the delay.
+
+ Returns
+ -------
+ delay_data: ArrayType
+ The delay data at the given time.
+ """
+ if delay_step is None:
+ return global_delay_data[identifier][1].value
+
+ if identifier in global_delay_data:
+ if bm.ndim(delay_step) == 0:
+ return global_delay_data[identifier][0](delay_step, *indices)
+ else:
+ if len(indices) == 0:
+ indices = (bm.arange(delay_step.size),)
+ return global_delay_data[identifier][0](delay_step, *indices)
+
+ elif identifier in self.local_delay_vars:
+ if bm.ndim(delay_step) == 0:
+ return self.local_delay_vars[identifier](delay_step)
+ else:
+ if len(indices) == 0:
+ indices = (bm.arange(delay_step.size),)
+ return self.local_delay_vars[identifier](delay_step, *indices)
+
+ else:
+ raise ValueError(f'{identifier} is not defined in delay variables.')
+
+ def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
+ """Update local delay variables.
+
+ This function should be called after updating neuron groups or delay sources.
+ For example, in a network model, it is typically called at the end of each
+ update step, after all neuron groups have produced their new values.
+
+ Parameters
+ ----------
+ nodes: sequence, dict
+ The nodes whose delay variables will be updated.
+ """
+ global DynamicalSystem
+ if DynamicalSystem is None:
+ from brainpy._src.dynsys import DynamicalSystem
+
+ # update delays
+ if nodes is None:
+ nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values())
+ elif isinstance(nodes, dict):
+ nodes = tuple(nodes.values())
+ if not isinstance(nodes, (tuple, list)):
+ nodes = (nodes,)
+ for node in nodes:
+ for name in node.local_delay_vars:
+ delay = global_delay_data[name][0]
+ target = global_delay_data[name][1]
+ delay.update(target.value)
+
+ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
+ """Reset local delay variables.
+
+ Parameters
+ ----------
+ nodes: sequence, dict
+ The nodes whose delay variables will be reset.
+ """
+ global DynamicalSystem
+ if DynamicalSystem is None:
+ from brainpy._src.dynsys import DynamicalSystem
+
+ # reset delays
+ if nodes is None:
+ nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()
+ elif isinstance(nodes, dict):
+ nodes = nodes.values()
+ for node in nodes:
+ for name in node.local_delay_vars:
+ delay = global_delay_data[name][0]
+ target = global_delay_data[name][1]
+ delay.reset(target.value)
+
+
+class BindCondData(MixIn):
+ """Bind temporary conductance data.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._conductance = None
+
+ def bind_cond(self, conductance):
+ self._conductance = conductance
+
+ def unbind_cond(self):
+ self._conductance = None
+
+
+T = TypeVar('T')
+
+
+def get_type(types):
+ class NewType(type):
+ def __instancecheck__(self, other):
+ cls_of_other = other.__class__
+ return all([issubclass(cls_of_other, cls) for cls in types])
+
+ return NewType
+
+
+class _MetaUnionType(type):
+ def __new__(cls, name, bases, dct):
+ if isinstance(bases, type):
+ bases = (bases,)
+ elif isinstance(bases, (list, tuple)):
+ bases = tuple(bases)
+ for base in bases:
+ assert isinstance(base, type), f'Must be type. But got {base}'
+ else:
+ raise TypeError(f'Must be type. But got {bases}')
+ return super().__new__(cls, name, bases, dct)
+
+ def __instancecheck__(self, other):
+ cls_of_other = other.__class__
+ return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
+
+ def __subclasscheck__(self, subclass):
+ return all([issubclass(subclass, cls) for cls in self.__bases__])
+
+
+class UnionType2(MixIn):
+ """Union type for multiple types.
+
+ >>> import brainpy as bp
+ >>>
+ >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.AutoDelaySupp])
+ """
+
+ @classmethod
+ def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type:
+ return _MetaUnionType('UnionType', types, {})
+
+
+class _JointGenericAlias(_UnionGenericAlias, _root=True):
+ def __subclasscheck__(self, subclass):
+ return all([issubclass(subclass, cls) for cls in set(self.__args__)])
+
+
+@_SpecialForm
+def JointType(self, parameters):
+ """Joint type; JointType[X, Y] means either X or Y.
+
+ To define a union, use e.g. Union[int, str]. Details:
+ - The arguments must be types and there must be at least one.
+ - None as an argument is a special case and is replaced by
+ type(None).
+ - Unions of unions are flattened, e.g.::
+
+ JointType[JointType[int, str], float] == JointType[int, str, float]
+
+ - Unions of a single argument vanish, e.g.::
+
+ JointType[int] == int # The constructor actually returns int
+
+ - Redundant arguments are skipped, e.g.::
+
+ JointType[int, str, int] == JointType[int, str]
+
+ - When comparing unions, the argument order is ignored, e.g.::
+
+ JointType[int, str] == JointType[str, int]
+
+ - You cannot subclass or instantiate a union.
+ - You can use Optional[X] as a shorthand for JointType[X, None].
+ """
+ if parameters == ():
+ raise TypeError("Cannot take a Union of no types.")
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+ msg = "JointType[arg, ...]: each arg must be a type."
+ parameters = tuple(_type_check(p, msg) for p in parameters)
+ parameters = _remove_dups_flatten(parameters)
+ if len(parameters) == 1:
+ return parameters[0]
+ return _JointGenericAlias(self, parameters)
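``JointType`` and the ``_MetaUnionType`` machinery both require an object to satisfy *all* of the listed types. A tiny standalone illustration of that intersection-style ``isinstance`` behaviour (plain classes, no BrainPy imports):

    class RequiresAll(type):
        # mirrors _MetaUnionType: the check must hold for *every* base (intersection, not union)
        def __instancecheck__(cls, obj):
            return all(isinstance(obj, base) for base in cls.__bases__)

        def __subclasscheck__(cls, sub):
            return all(issubclass(sub, base) for base in cls.__bases__)

    class A: pass
    class B: pass
    class AB(A, B): pass

    Joint = RequiresAll('Joint', (A, B), {})
    print(isinstance(AB(), Joint))   # True: AB is both an A and a B
    print(isinstance(A(), Joint))    # False: only one of the two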
diff --git a/brainpy/_src/neurons/compat.py b/brainpy/_src/neurons/compat.py
deleted file mode 100644
index 8a0c750c3..000000000
--- a/brainpy/_src/neurons/compat.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-from .biological_models import HH, MorrisLecar, PinskyRinzelModel
-from .fractional_models import FractionalFHR, FractionalIzhikevich
-from .reduced_models import LIF, ExpIF, AdExIF, QuaIF, AdQuaIF, GIF, Izhikevich, HindmarshRose, FHN
-from .input_groups import SpikeTimeGroup, PoissonGroup
-from .noise_groups import OUProcess
-
-__all__ = [
- 'HH', 'MorrisLecar', 'PinskyRinzelModel',
- 'FractionalFHR', 'FractionalIzhikevich',
- 'LIF', 'ExpIF', 'AdExIF', 'QuaIF', 'AdQuaIF',
- 'GIF', 'Izhikevich', 'HindmarshRose', 'FHN',
- 'SpikeTimeGroup', 'PoissonGroup', 'OUProcess'
-]
diff --git a/brainpy/_src/neurons/input_groups.py b/brainpy/_src/neurons/input_groups.py
deleted file mode 100644
index e49645253..000000000
--- a/brainpy/_src/neurons/input_groups.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from typing import Union, Sequence
-
-import jax
-import jax.numpy as jnp
-from brainpy._src.context import share
-import brainpy.math as bm
-from brainpy._src.dynsys import NeuGroupNS
-from brainpy._src.initialize import Initializer, parameter, variable_
-from brainpy.types import Shape, ArrayType
-
-__all__ = [
- 'InputGroup',
- 'OutputGroup',
- 'SpikeTimeGroup',
- 'PoissonGroup',
-]
-
-
-class InputGroup(NeuGroupNS):
- """Input neuron group for place holder.
-
- Parameters
- ----------
- size: int, tuple of int
- keep_size: bool
- mode: Mode
- name: str
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None,
- ):
- super(InputGroup, self).__init__(name=name,
- size=size,
- keep_size=keep_size,
- mode=mode)
- self.spike = None
-
- def update(self, x):
- return x
-
- def reset_state(self, batch_size=None):
- pass
-
-
-class OutputGroup(NeuGroupNS):
- """Output neuron group for place holder.
-
- Parameters
- ----------
- size: int, tuple of int
- keep_size: bool
- mode: Mode
- name: str
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None,
- ):
- super(OutputGroup, self).__init__(name=name,
- size=size,
- keep_size=keep_size,
- mode=mode)
- self.spike = None
-
- def update(self, x):
- return x
-
- def reset_state(self, batch_size=None):
- pass
-
-
-class SpikeTimeGroup(NeuGroupNS):
- """The input neuron group characterized by spikes emitting at given times.
-
- >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms.
- >>> SpikeTimeGroup(2, times=[10, 20])
- >>> # or
- >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
- >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0])
- >>> # or
- >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
- >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0])
- >>> # or
- >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
- >>> # at 30 ms, neuron 1 fires.
- >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
-
- Parameters
- ----------
- size : int, tuple, list
- The neuron group geometry.
- indices : list, tuple, ArrayType
- The neuron indices at each time point to emit spikes.
- times : list, tuple, ArrayType
- The time points which generate the spikes.
- name : str, optional
- The name of the dynamic system.
- """
-
- def __init__(
- self,
- size: Shape,
- times: Union[Sequence, ArrayType],
- indices: Union[Sequence, ArrayType],
- need_sort: bool = True,
- keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None
- ):
- super(SpikeTimeGroup, self).__init__(size=size,
- name=name,
- keep_size=keep_size,
- mode=mode)
-
- # parameters
- if keep_size:
- raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}')
- if len(indices) != len(times):
- raise ValueError(f'The length of "indices" and "times" must be the same. '
- f'However, we got {len(indices)} != {len(times)}.')
- self.num_times = len(times)
-
- # data about times and indices
- self.times = bm.asarray(times)
- self.indices = bm.asarray(indices, dtype=bm.int_)
- if need_sort:
- sort_idx = bm.argsort(self.times)
- self.indices.value = self.indices[sort_idx]
- self.times.value = self.times[sort_idx]
-
- # variables
- self.reset_state(self.mode)
-
- def reset_state(self, batch_size=None):
- self.i = bm.Variable(bm.asarray(0))
- self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
-
- def update(self):
- self.spike.value = bm.zeros_like(self.spike)
- bm.while_loop(self._body_fun, self._cond_fun, share.load('t'))
- return self.spike.value
-
- # functions
- def _cond_fun(self, t):
- i = self.i.value
- return bm.logical_and(i < self.num_times, t >= self.times[i])
-
- def _body_fun(self, t):
- i = self.i.value
- if isinstance(self.mode, bm.BatchingMode):
- self.spike[:, self.indices[i]] = True
- else:
- self.spike[self.indices[i]] = True
- self.i += 1
-
-
-class PoissonGroup(NeuGroupNS):
- """Poisson Neuron Group.
- """
-
- def __init__(
- self,
- size: Shape,
- freqs: Union[int, float, jnp.ndarray, bm.Array, Initializer],
- seed: int = None,
- keep_size: bool = False,
- mode: bm.Mode = None,
- name: str = None
- ):
- super(PoissonGroup, self).__init__(size=size,
- name=name,
- keep_size=keep_size,
- mode=mode)
-
- # parameters
- self.keep_size = keep_size
- self.seed = seed
- self.freqs = parameter(freqs, self.num, allow_none=False)
-
- # variables
- self.reset_state(self.mode)
-
- def update(self):
- spikes = bm.random.rand_like(self.spike) <= (self.freqs * share.dt / 1000.)
- self.spike.value = spikes
- return spikes
-
- def reset_state(self, batch_size=None):
- self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size)
-
diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py
index 91c898701..1bfd9cc61 100644
--- a/brainpy/_src/runners.py
+++ b/brainpy/_src/runners.py
@@ -83,27 +83,19 @@ def check_and_format_inputs(host, inputs):
# checking 1: absolute access
# Check whether the input target node is accessible,
# and check whether the target node has the attribute
- nodes = None
for one_input in inputs:
key = one_input[0]
if isinstance(key, bm.Variable):
real_target = key
elif isinstance(key, str):
- if nodes is None:
- nodes = host.nodes(method='absolute', level=-1, include_self=True)
splits = key.split('.')
- target = '.'.join(splits[:-1])
- key = splits[-1]
- if target == '':
- real_target = host
- else:
- if target not in nodes:
- inputs_not_found_target.append(one_input)
- continue
- real_target = nodes[target]
- if not hasattr(real_target, key):
- raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
- real_target = getattr(real_target, key)
+ target = host
+ try:
+ for split in splits:
+ target = getattr(target, split)
+ except AttributeError:
+ raise AttributeError(f'Target {target} does not have the attribute "{split}".')
+ real_target = target
else:
raise RunningError(f'For each input, input[0] must be a string to '
f'specify variable of the target, but we got {key}.')
@@ -112,18 +104,18 @@ def check_and_format_inputs(host, inputs):
# checking 2: relative access
# Check whether the input target node is accessible
# and check whether the target node has the attribute
- if len(inputs_not_found_target):
- nodes = host.nodes(method='relative', level=-1, include_self=True)
- for one_input in inputs_not_found_target:
- splits = one_input[0].split('.')
- target, key = '.'.join(splits[:-1]), splits[-1]
- if target not in nodes:
- raise RunningError(f'Input target "{target}" is not defined in {host}.')
- real_target = nodes[target]
- if not hasattr(real_target, key):
- raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
- real_target = getattr(real_target, key)
- inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
+ # if len(inputs_not_found_target):
+ # nodes = host.nodes(method='relative', level=-1, include_self=True)
+ # for one_input in inputs_not_found_target:
+ # splits = one_input[0].split('.')
+ # target, key = '.'.join(splits[:-1]), splits[-1]
+ # if target not in nodes:
+ # raise RunningError(f'Input target "{target}" is not defined in {host}.')
+ # real_target = nodes[target]
+ # if not hasattr(real_target, key):
+ # raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
+ # real_target = getattr(real_target, key)
+ # inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
# 3. format inputs
# ---------
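The rewritten lookup resolves a dotted input key by walking ``getattr`` from the host instead of pre-collecting every node. A minimal sketch of the new resolution logic with hypothetical objects:

    class Post: pass
    class Net: pass

    net = Net()
    net.post = Post()
    net.post.input = 0.0

    def resolve(host, key):
        # 'post.input' -> getattr(host, 'post') -> getattr(<post>, 'input')
        target = host
        for split in key.split('.'):
            target = getattr(target, split)
        return target

    print(resolve(net, 'post.input'))   # 0.0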
diff --git a/brainpy/_src/synapses/biological_models.py b/brainpy/_src/synapses/biological_models.py
deleted file mode 100644
index 9bf9c1c03..000000000
--- a/brainpy/_src/synapses/biological_models.py
+++ /dev/null
@@ -1,587 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from typing import Union, Dict, Callable, Optional
-
-from jax import vmap
-from jax.lax import stop_gradient
-
-import brainpy.math as bm
-from brainpy._src.dynsys import NeuGroup, TwoEndConn, SynSTP, SynOut
-from brainpy._src.synouts import COBA, MgBlock
-from brainpy._src.initialize import Initializer, variable
-from brainpy._src.integrators import odeint, JointEq
-from brainpy._src.connect import TwoEndConnector, All2All, One2One
-from brainpy.types import ArrayType
-
-__all__ = [
- 'AMPA',
- 'GABAa',
- 'BioNMDA',
-]
-
-
-class AMPA(TwoEndConn):
- r"""AMPA synapse model.
-
- **Model Descriptions**
-
- AMPA receptor is an ionotropic receptor, which is an ion channel.
- When it is bound by neurotransmitters, it will immediately open the
- ion channel, causing the change of membrane potential of postsynaptic neurons.
-
- A classical model is to use the Markov process to model ion channel switch.
- Here :math:`g` represents the probability of channel opening, :math:`1-g`
- represents the probability of ion channel closing, and :math:`\alpha` and
- :math:`\beta` are the transition probability. Because neurotransmitters can
- open ion channels, the transfer probability from :math:`1-g` to :math:`g`
- is affected by the concentration of neurotransmitters. We denote the concentration
- of neurotransmitters as :math:`[T]` and get the following Markov process.
-
- .. image:: ../../../_static/synapse_markov.png
- :align: center
-
- We obtained the following formula when describing the process by a differential equation.
-
- .. math::
-
- \frac{ds}{dt} =\alpha[T](1-g)-\beta g
-
- where :math:`\alpha [T]` denotes the transition probability from state :math:`(1-g)`
- to state :math:`(g)`; and :math:`\beta` represents the transition probability of
- the other direction. :math:`\alpha` is the binding constant. :math:`\beta` is the
- unbinding constant. :math:`[T]` is the neurotransmitter concentration, and
- has the duration of 0.5 ms.
-
- Moreover, the post-synaptic current on the post-synaptic neuron is formulated as
-
- .. math::
-
- I_{syn} = g_{max} g (V-E)
-
- where :math:`g_{max}` is the maximum conductance, and `E` is the reverse potential.
-
- **Model Examples**
-
-
- .. plot::
- :include-source: True
-
- >>> import brainpy as bp
- >>> from brainpy import neurons, synapses
- >>> import matplotlib.pyplot as plt
- >>>
- >>> neu1 = neurons.HH(1)
- >>> neu2 = neurons.HH(1)
- >>> syn1 = synapses.AMPA(neu1, neu2, bp.connect.All2All())
- >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
- >>>
- >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g'])
- >>> runner.run(150.)
- >>>
- >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
- >>> fig.add_subplot(gs[0, 0])
- >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
- >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
- >>> plt.legend()
- >>>
- >>> fig.add_subplot(gs[1, 0])
- >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
- >>> plt.legend()
- >>> plt.show()
-
- Parameters
- ----------
- pre: NeuGroup
- The pre-synaptic neuron group.
- post: NeuGroup
- The post-synaptic neuron group.
- conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
- The synaptic connections.
- comp_method: str
- The connection type used for model speed optimization. It can be
- `sparse` and `dense`. The default is `dense`.
- delay_step: int, ArrayType, Initializer, Callable
- The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
- E: float, ArrayType
- The reversal potential for the synaptic current. [mV]
-
- .. deprecated:: 2.1.13
- `E` is deprecated in AMPA model. Please define `E` with brainpy.dyn.synouts.COBA.
- This parameter will be removed since 2.2.0
-
- g_max: float, ArrayType, Initializer, Callable
- The synaptic strength (the maximum conductance). Default is 1.
- alpha: float, ArrayType
- Binding constant.
- beta: float, ArrayType
- Unbinding constant.
- T: float, ArrayType
- Transmitter concentration when synapse is triggered by
- a pre-synaptic spike.. Default 1 [mM].
- T_duration: float, ArrayType
- Transmitter concentration duration time after being triggered. Default 1 [ms]
- name: str
- The name of this synaptic projection.
- method: str
- The numerical integration methods.
-
- References
- ----------
-
- .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
- and implications for stimulus processing[J]. Proceedings of the
- National Academy of Sciences, 2012, 109(45): 18553-18558.
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = COBA(E=0.),
- stp: Optional[SynSTP] = None,
- comp_method: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 0.42,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- alpha: float = 0.98,
- beta: float = 0.18,
- T: float = 0.5,
- T_duration: float = 0.5,
- method: str = 'exp_auto',
-
- # other parameters
- name: str = None,
- mode: bm.Mode = None,
- stop_spike_gradient: bool = False,
- ):
- super(AMPA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=output,
- stp=stp,
- name=name,
- mode=mode)
-
- # parameters
- self.stop_spike_gradient = stop_spike_gradient
- self.comp_method = comp_method
- self.alpha = alpha
- self.beta = beta
- self.T = T
- self.T_duration = T_duration
- if bm.size(alpha) != 1:
- raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. But we got {alpha}')
- if bm.size(beta) != 1:
- raise ValueError(f'"beta" must be a scalar or a tensor with size of 1. But we got {beta}')
- if bm.size(T) != 1:
- raise ValueError(f'"T" must be a scalar or a tensor with size of 1. But we got {T}')
- if bm.size(T_duration) != 1:
- raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')
-
- # connection
- self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')
-
- # variables
- self.g = variable(bm.zeros, self.mode, self.pre.num)
- self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, self.mode, self.pre.num)
- self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
-
- # functions
- self.integral = odeint(method=method, f=self.dg)
-
- def reset_state(self, batch_size=None):
- self.g = variable(bm.zeros, batch_size, self.pre.num)
- self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num)
- self.output.reset_state(batch_size)
- if self.stp is not None: self.stp.reset_state(batch_size)
-
- def dg(self, g, t, TT):
- dg = self.alpha * TT * (1 - g) - self.beta * g
- return dg
-
- def update(self, tdi, pre_spike=None):
- t, dt = tdi['t'], tdi['dt']
-
- # delays
- if pre_spike is None:
- pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
- pre_spike = bm.as_jax(pre_spike)
- if self.stop_spike_gradient:
- pre_spike = stop_gradient(pre_spike)
-
- # update sub-components
- self.output.update(tdi)
- if self.stp is not None: self.stp.update(tdi, pre_spike)
-
- # update synaptic variables
- self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time.value)
- if isinstance(self.mode, bm.TrainingMode):
- self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value)
- TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T
- self.g.value = self.integral(self.g, t, TT, dt)
-
- # post-synaptic values
- syn_value = self.g.value
- if self.stp is not None: syn_value = self.stp(syn_value)
- if isinstance(self.conn, All2All):
- post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
- elif isinstance(self.conn, One2One):
- post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
- else:
- if self.comp_method == 'sparse':
- f = lambda s: bm.sparse.csrmv(
- self.g_max, self.conn_mask[0], self.conn_mask[1], s,
- shape=(self.pre.num, self.post.num),
- transpose=True
- )
- if isinstance(self.mode, bm.BatchingMode):
- f = vmap(f)
- post_vs = f(syn_value)
- else:
- post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
-
- # output
- return self.output(post_vs)
-
-
-class GABAa(AMPA):
- r"""GABAa synapse model.
-
- **Model Descriptions**
-
- GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_,
-
- .. math::
-
- \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
- I_{syn}&= - g_{max} g (V - E)
-
- but with the difference of:
-
- - Reversal potential of synapse :math:`E` is usually low, typically -80. mV
- - Activating rate constant :math:`\alpha=0.53`
- - De-activating rate constant :math:`\beta=0.18`
- - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is
- triggered by a pre-synaptic spike, with the duration of 1. ms.
-
- **Model Examples**
-
- - `Gamma oscillation network model `_
-
-
- Parameters
- ----------
- pre: NeuGroup
- The pre-synaptic neuron group.
- post: NeuGroup
- The post-synaptic neuron group.
- conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
- The synaptic connections.
- comp_method: str
- The connection type used for model speed optimization. It can be
- `sparse` and `dense`. The default is `dense`.
- delay_step: int, ArrayType, Initializer, Callable
- The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
- g_max: float, ArrayType, Initializer, Callable
- The synaptic strength (the maximum conductance). Default is 1.
- alpha: float, ArrayType
- Binding constant. Default 0.062
- beta: float, ArrayType
- Unbinding constant. Default 3.57
- T: float, ArrayType
- Transmitter concentration when synapse is triggered by
- a pre-synaptic spike.. Default 1 [mM].
- T_duration: float, ArrayType
- Transmitter concentration duration time after being triggered. Default 1 [ms]
- name: str
- The name of this synaptic projection.
- method: str
- The numerical integration methods.
-
- References
- ----------
- .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
- on the integrative properties of neocortical pyramidal neurons
- in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = COBA(E=-80.),
- stp: Optional[SynSTP] = None,
- comp_method: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 0.04,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- alpha: Union[float, ArrayType] = 0.53,
- beta: Union[float, ArrayType] = 0.18,
- T: Union[float, ArrayType] = 1.,
- T_duration: Union[float, ArrayType] = 1.,
- method: str = 'exp_auto',
-
- # other parameters
- name: str = None,
- mode: bm.Mode = None,
- stop_spike_gradient: bool = False,
- ):
- super(GABAa, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=output,
- stp=stp,
- comp_method=comp_method,
- delay_step=delay_step,
- g_max=g_max,
- alpha=alpha,
- beta=beta,
- T=T,
- T_duration=T_duration,
- method=method,
- name=name,
- mode=mode,
- stop_spike_gradient=stop_spike_gradient, )
-
-
-class BioNMDA(TwoEndConn):
- r"""Biological NMDA synapse model.
-
- **Model Descriptions**
-
- The NMDA receptor is a glutamate receptor and ion channel found in neurons.
- The NMDA receptor is one of three types of ionotropic glutamate receptors,
- the other two being AMPA and kainate receptors.
-
- The NMDA receptor mediated conductance depends on the postsynaptic voltage.
- The voltage dependence is due to the blocking of the pore of the NMDA receptor
- from the outside by a positively charged magnesium ion. The channel is
- nearly completely blocked at resting potential, but the magnesium block is
- relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}`
- that are not blocked by magnesium can be fitted to
-
- .. math::
-
- g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V}
- \frac{[{Mg}^{2+}]_{o}} {b})^{-1}
-
- Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration,
- usually 1 mM. Thus, the channel acts as a
- "coincidence detector" and only once both of these conditions are met, the
- channel opens and it allows positively charged ions (cations) to flow through
- the cell membrane [2]_.
-
- If we make the approximation that the magnesium block changes
- instantaneously with voltage and is independent of the gating of the channel,
- the net NMDA receptor-mediated synaptic current is given by
-
- .. math::
-
- I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty}
-
- where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the
- reversal potential.
-
- Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_:
-
- .. math::
-
- & g_\mathrm{NMDA} (t) = g_{max} g \\
- & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\
- & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x
-
- where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and
- :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x.
-
- The NMDA receptor has been thought to be very important for controlling
- synaptic plasticity and mediating learning and memory functions [3]_.
-
- .. plot::
- :include-source: True
-
- >>> import brainpy as bp
- >>> from brainpy import neurons, synapses
- >>> import matplotlib.pyplot as plt
- >>>
- >>> neu1 = neurons.HH(1)
- >>> neu2 = neurons.HH(1)
- >>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All())
- >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
- >>>
- >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x'])
- >>> runner.run(150.)
- >>>
- >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
- >>> fig.add_subplot(gs[0, 0])
- >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
- >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
- >>> plt.legend()
- >>>
- >>> fig.add_subplot(gs[1, 0])
- >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
- >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x')
- >>> plt.legend()
- >>> plt.show()
-
- Parameters
- ----------
- pre: NeuGroup
- The pre-synaptic neuron group.
- post: NeuGroup
- The post-synaptic neuron group.
- conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
- The synaptic connections.
- comp_method: str
- The connection type used for model speed optimization. It can be
- `sparse` and `dense`. The default is `dense`.
- delay_step: int, ArrayType, Initializer, Callable
- The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
- g_max: float, ArrayType, Initializer, Callable
- The synaptic strength (the maximum conductance). Default is 1.
- alpha1: float, ArrayType
- The conversion rate of g from inactive to active. Default 2 ms^-1.
- beta1: float, ArrayType
- The conversion rate of g from active to inactive. Default 0.01 ms^-1.
- alpha2: float, ArrayType
- The conversion rate of x from inactive to active. Default 1 ms^-1.
- beta2: float, ArrayType
- The conversion rate of x from active to inactive. Default 0.5 ms^-1.
- name: str
- The name of this synaptic projection.
- method: str
- The numerical integration methods.
-
- References
- ----------
-
- .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M].
- Springer New York, 2010: 162.
- .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and
- Eric Gouaux. "Subunit arrangement and function in NMDA receptors."
- Nature 438, no. 7065 (2005): 185-192.
- .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New
- England journal of medicine, 361(3), p.302.
- .. [4] https://en.wikipedia.org/wiki/NMDA_receptor
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = MgBlock(E=0.),
- stp: Optional[SynSTP] = None,
- comp_method: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 0.15,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- alpha1: Union[float, ArrayType] = 2.,
- beta1: Union[float, ArrayType] = 0.01,
- alpha2: Union[float, ArrayType] = 1.,
- beta2: Union[float, ArrayType] = 0.5,
- T_0: Union[float, ArrayType] = 1.,
- T_dur: Union[float, ArrayType] = 0.5,
- method: str = 'exp_auto',
-
- # other parameters
- mode: bm.Mode = None,
- name: str = None,
- stop_spike_gradient: bool = False,
- ):
- super(BioNMDA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=output,
- stp=stp,
- name=name,
- mode=mode)
-
- # parameters
- self.beta1 = beta1
- self.beta2 = beta2
- self.alpha1 = alpha1
- self.alpha2 = alpha2
- self.T_0 = T_0
- self.T_dur = T_dur
- if bm.size(alpha1) != 1:
- raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}')
- if bm.size(beta1) != 1:
- raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}')
- if bm.size(alpha2) != 1:
- raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}')
- if bm.size(beta2) != 1:
- raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}')
- if bm.size(T_0) != 1:
- raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}')
- if bm.size(T_dur) != 1:
- raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. But we got {T_dur}')
- self.comp_method = comp_method
- self.stop_spike_gradient = stop_spike_gradient
-
- # connections and weights
- self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')
-
- # variables
- self.g = variable(bm.zeros, self.mode, self.pre.num)
- self.x = variable(bm.zeros, self.mode, self.pre.num)
- self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, self.mode, self.pre.num)
- self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
-
- # integral
- self.integral = odeint(method=method, f=JointEq([self.dg, self.dx]))
-
- def reset_state(self, batch_size=None):
- self.g = variable(bm.zeros, batch_size, self.pre.num)
- self.x = variable(bm.zeros, batch_size, self.pre.num)
- self.spike_arrival_time = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.pre.num)
- self.output.reset_state(batch_size)
- if self.stp is not None: self.stp.reset_state(batch_size)
-
- def dg(self, g, t, x):
- return self.alpha1 * x * (1 - g) - self.beta1 * g
-
- def dx(self, x, t, T):
- return self.alpha2 * T * (1 - x) - self.beta2 * x
-
- def update(self, tdi, pre_spike=None):
- t, dt = tdi['t'], tdi['dt']
-
- # pre-synaptic spikes
- if pre_spike is None:
- pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
- pre_spike = bm.as_jax(pre_spike)
- if self.stop_spike_gradient:
- pre_spike = stop_gradient(pre_spike)
-
- # update sub-components
- self.output.update(tdi)
- if self.stp is not None: self.stp.update(tdi, pre_spike)
-
- # update synapse variables
- self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time.value)
- if isinstance(self.mode, bm.TrainingMode):
- self.spike_arrival_time.value = stop_gradient(self.spike_arrival_time.value)
- T = ((t - self.spike_arrival_time) < self.T_dur) * self.T_0
- self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt)
-
- # post-synaptic value
- syn_value = self.g.value
- if self.stp is not None: syn_value = self.stp(syn_value)
- if isinstance(self.conn, All2All):
- post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
- elif isinstance(self.conn, One2One):
- post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
- else:
- if self.comp_method == 'sparse':
- f = lambda s: bm.sparse.csrmv(
- self.g_max,self.conn_mask[0], self.conn_mask[1], s,
- shape=(self.pre.num, self.post.num),
- transpose=True
- )
- if isinstance(self.mode, bm.BatchingMode): f = vmap(f)
- post_vs = f(syn_value)
- else:
- post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
-
- # output
- return self.output(post_vs)
diff --git a/brainpy/_src/synapses/compat.py b/brainpy/_src/synapses/compat.py
deleted file mode 100644
index 40b66b5c7..000000000
--- a/brainpy/_src/synapses/compat.py
+++ /dev/null
@@ -1,300 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import warnings
-from typing import Union, Dict, Callable, Optional
-
-import brainpy._src.math as bm
-from brainpy._src.connect import TwoEndConnector
-from brainpy._src.dynsys import NeuGroup, SynSTP
-from brainpy._src.synouts import COBA, CUBA, MgBlock
-from brainpy._src.initialize import Initializer
-from brainpy.types import ArrayType
-from .abstract_models import Delta, Exponential, DualExponential, NMDA as NewNMDA
-
-__all__ = [
- 'DeltaSynapse',
- 'ExpCUBA',
- 'ExpCOBA',
- 'DualExpCUBA',
- 'DualExpCOBA',
- 'AlphaCUBA',
- 'AlphaCOBA',
- 'NMDA',
-]
-
-
-class DeltaSynapse(Delta):
- """Delta synapse.
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.Delta" instead.
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'sparse',
- weights: Union[float, ArrayType, Initializer, Callable] = 1.,
- delay_step: Union[float, ArrayType, Initializer, Callable] = None,
- post_input_key: str = 'V',
- post_has_ref: bool = False,
- name: str = None,
- ):
- warnings.warn('Please use "brainpy.synapses.Delta" instead.', DeprecationWarning)
- super(DeltaSynapse, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=CUBA(post_input_key),
- name=name,
- comp_method=conn_type,
- g_max=weights,
- delay_step=delay_step,
- post_ref_key='refractory' if post_has_ref else None)
-
-
-class ExpCUBA(Exponential):
- r"""Current-based exponential decay synapse model.
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.Exponential" instead.
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'sparse',
- g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- tau: Union[float, ArrayType] = 8.0,
- name: str = None,
- method: str = 'exp_auto',
- ):
- super(ExpCUBA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- name=name,
- comp_method=conn_type,
- g_max=g_max,
- delay_step=delay_step,
- tau=tau,
- method=method,
- output=CUBA())
-
-
-class ExpCOBA(Exponential):
- """Conductance-based exponential decay synapse model.
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.Exponential" instead.
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- # connection
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'sparse',
- # connection strength
- g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
- # synapse parameter
- tau: Union[float, ArrayType] = 8.0,
- E: Union[float, ArrayType] = 0.,
- # synapse delay
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- # others
- method: str = 'exp_auto',
- name: str = None
- ):
- super(ExpCOBA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- comp_method=conn_type,
- g_max=g_max,
- delay_step=delay_step,
- tau=tau,
- method=method,
- name=name,
- output=COBA(E=E))
-
-
-class DualExpCUBA(DualExponential):
- r"""Current-based dual exponential synapse model.
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.DualExponential" instead.
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
- tau_decay: Union[float, ArrayType] = 10.0,
- tau_rise: Union[float, ArrayType] = 1.,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- method: str = 'exp_auto',
- name: str = None
- ):
- super(DualExpCUBA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- comp_method=conn_type,
- g_max=g_max,
- tau_decay=tau_decay,
- tau_rise=tau_rise,
- delay_step=delay_step,
- method=method,
- name=name,
- output=CUBA())
-
-
-class DualExpCOBA(DualExponential):
- """Conductance-based dual exponential synapse model.
-
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.DualExponential" instead.
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- tau_decay: Union[float, ArrayType] = 10.0,
- tau_rise: Union[float, ArrayType] = 1.,
- E: Union[float, ArrayType] = 0.,
- method: str = 'exp_auto',
- name: str = None
- ):
- super(DualExpCOBA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- comp_method=conn_type,
- g_max=g_max,
- tau_decay=tau_decay,
- tau_rise=tau_rise,
- delay_step=delay_step,
- method=method,
- name=name,
- output=COBA(E=E))
-
-
-class AlphaCUBA(DualExpCUBA):
- r"""Current-based alpha synapse model.
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.Alpha" instead.
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- tau_decay: Union[float, ArrayType] = 10.0,
- method: str = 'exp_auto',
- name: str = None
- ):
- super(AlphaCUBA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- conn_type=conn_type,
- delay_step=delay_step,
- g_max=g_max,
- tau_decay=tau_decay,
- tau_rise=tau_decay,
- method=method,
- name=name)
-
-
-class AlphaCOBA(DualExpCOBA):
- """Conductance-based alpha synapse model.
-
- .. deprecated:: 2.1.13
- Please use "brainpy.synapses.Alpha" instead.
-
- """
-
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- conn_type: str = 'dense',
- g_max: Union[float, ArrayType, Callable, Initializer] = 1.,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- tau_decay: Union[float, ArrayType] = 10.0,
- E: Union[float, ArrayType] = 0.,
- method: str = 'exp_auto',
- name: str = None
- ):
- super(AlphaCOBA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- conn_type=conn_type,
- delay_step=delay_step,
- g_max=g_max, E=E,
- tau_decay=tau_decay,
- tau_rise=tau_decay,
- method=method,
- name=name)
-
-
-class NMDA(NewNMDA):
- def __init__(
- self,
- pre: NeuGroup,
- post: NeuGroup,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- E=0.,
- alpha=0.062,
- beta=3.57,
- cc_Mg=1.2,
- stp: Optional[SynSTP] = None,
- comp_method: str = 'dense',
- g_max: Union[float, ArrayType, Initializer, Callable] = 0.15,
- delay_step: Union[int, ArrayType, Initializer, Callable] = None,
- tau_decay: Union[float, ArrayType] = 100.,
- a: Union[float, ArrayType] = 0.5,
- tau_rise: Union[float, ArrayType] = 2.,
- method: str = 'exp_auto',
-
- # other parameters
- name: str = None,
- mode: bm.Mode = None,
- stop_spike_gradient: bool = False,
- ):
- super(NMDA, self).__init__(pre=pre,
- post=post,
- conn=conn,
- output=MgBlock(E=E, alpha=alpha, beta=beta, cc_Mg=cc_Mg),
- stp=stp,
- name=name,
- mode=mode,
- comp_method=comp_method,
- g_max=g_max,
- delay_step=delay_step,
- tau_decay=tau_decay,
- a=a,
- tau_rise=tau_rise,
- method=method,
- stop_spike_gradient=stop_spike_gradient)
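The removed compatibility wrappers above are thin forwards to the newer brainpy.synapses classes, as their super().__init__ calls show. A minimal migration sketch follows; the neuron sizes and connection rules are illustrative only, while the keyword forwarding mirrors the deleted wrappers themselves:

import brainpy as bp

pre = bp.neurons.LIF(10)
post = bp.neurons.LIF(10)

# ExpCOBA(pre, post, conn, g_max=1., tau=8., E=0.) forwarded to Exponential + COBA
syn = bp.synapses.Exponential(pre, post, conn=bp.conn.FixedProb(0.5),
                              g_max=1., tau=8.,
                              output=bp.synouts.COBA(E=0.))

# DeltaSynapse(pre, post, conn, weights=1., post_input_key='V') forwarded to Delta + CUBA
delta = bp.synapses.Delta(pre, post, conn=bp.conn.All2All(),
                          g_max=1., output=bp.synouts.CUBA('V'))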
diff --git a/brainpy/_src/synapses/tests/test_abstract_synapses.py b/brainpy/_src/synapses/tests/test_abstract_synapses.py
deleted file mode 100644
index a714b493c..000000000
--- a/brainpy/_src/synapses/tests/test_abstract_synapses.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-import brainpy as bp
-import brainpy.math as bm
-from absl.testing import parameterized
-from brainpy._src.synapses import abstract_models
-
-
-class Test_Abstract_Synapse(parameterized.TestCase):
- @parameterized.named_parameters(
- {'testcase_name': f'noise_of_{name}', 'synapse': name}
- for name in ['Exponential', 'DualExponential', 'Alpha', 'NMDA']
- )
- def test_all2all_synapse(self, synapse):
- pre_neu = bp.neurons.LIF(5)
- post_neu = bp.neurons.LIF(5)
- syn_model = getattr(abstract_models, synapse)
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All())
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'syn.g', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['syn.g'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 5))
-
- @parameterized.named_parameters(
- {'testcase_name': f'noise_of_{name}', 'synapse': name}
- for name in ['Exponential', 'DualExponential', 'Alpha', 'NMDA']
- )
- def test_one2one_synapse(self, synapse):
- pre_neu = bp.neurons.LIF(5)
- post_neu = bp.neurons.LIF(5)
- syn_model = getattr(abstract_models, synapse)
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One())
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'syn.g', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['syn.g'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 5))
-
- @parameterized.named_parameters(
- {'testcase_name': f'noise_of_{name}', 'synapse': name}
- for name in ['Exponential', 'DualExponential', 'Alpha', 'NMDA']
- )
- def test_sparse_synapse(self, synapse):
- pre_neu = bp.neurons.LIF(5)
- post_neu = bp.neurons.LIF(5)
- syn_model = getattr(abstract_models, synapse)
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method='sparse')
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'syn.g', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['syn.g'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 5))
-
-
- def test_delta_synapse(self):
- pre_neu = bp.neurons.LIF(5)
- post_neu = bp.neurons.LIF(3)
- syn_model = abstract_models.Delta
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All())
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 3))
diff --git a/brainpy/_src/synapses/tests/test_biological_synapses.py b/brainpy/_src/synapses/tests/test_biological_synapses.py
deleted file mode 100644
index 8b25fc26f..000000000
--- a/brainpy/_src/synapses/tests/test_biological_synapses.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-import brainpy as bp
-import brainpy.math as bm
-from absl.testing import parameterized
-from brainpy._src.synapses import biological_models
-
-
-class Test_Biological_Synapse(parameterized.TestCase):
- @parameterized.named_parameters(
- {'testcase_name': f'noise_of_{name}', 'synapse': name}
- for name in biological_models.__all__
- )
- def test_all2all_synapse(self, synapse):
- pre_neu = bp.neurons.LIF(5)
- post_neu = bp.neurons.LIF(5)
- syn_model = getattr(biological_models, synapse)
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All())
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'syn.g', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['syn.g'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 5))
-
- @parameterized.named_parameters(
- {'testcase_name': f'noise_of_{name}', 'synapse': name}
- for name in biological_models.__all__
- )
- def test_one2one_synapse(self, synapse):
- pre_neu = bp.neurons.LIF(5)
- post_neu = bp.neurons.LIF(5)
- syn_model = getattr(biological_models, synapse)
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One())
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'syn.g', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['syn.g'].shape, (100, 5))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 5))
-
- @parameterized.named_parameters(
- {'testcase_name': f'noise_of_{name}', 'synapse': name}
- for name in biological_models.__all__
- )
- def test_sparse_synapse(self, synapse):
- pre_neu = bp.neurons.LIF(10)
- post_neu = bp.neurons.LIF(10)
- syn_model = getattr(biological_models, synapse)
- syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), comp_method='sparse')
- net = bp.Network(pre=pre_neu, syn=syn, post=post_neu)
-
-    # run the simulation
- runner = bp.DSRunner(net,
- monitors=['pre.V', 'syn.g', 'post.V'],
- inputs=('pre.input', 35.))
- runner(10.)
- self.assertTupleEqual(runner.mon['pre.V'].shape, (100, 10))
- self.assertTupleEqual(runner.mon['syn.g'].shape, (100, 10))
- self.assertTupleEqual(runner.mon['post.V'].shape, (100, 10))
diff --git a/brainpy/_src/synapses/tests/test_learning_rule.py b/brainpy/_src/synapses/tests/test_learning_rule.py
deleted file mode 100644
index 8da2651ee..000000000
--- a/brainpy/_src/synapses/tests/test_learning_rule.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# -*- coding: utf-8 -*-
-
-
-import brainpy as bp
-import brainpy.math as bm
-from absl.testing import parameterized
-from brainpy._src.synapses import learning_rules
-
-class Test_learning_rule(parameterized.TestCase):
- def test_learning_rule(self):
- neu1 = bp.neurons.LIF(5)
- neu2 = bp.neurons.LIF(5)
- syn1 = learning_rules.STP(neu1, neu2, bp.connect.All2All(), U=0.1, tau_d=10, tau_f=100.)
- net = bp.Network(pre=neu1, syn=syn1, post=neu2)
-
- runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x'])
- runner.run(10.)
- self.assertTupleEqual(runner.mon['syn.I'].shape, (100, 25))
- self.assertTupleEqual(runner.mon['syn.u'].shape, (100, 25))
- self.assertTupleEqual(runner.mon['syn.x'].shape, (100, 25))
\ No newline at end of file
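The (100, 25) shapes asserted in this deleted STP test follow the same logic: 100 recorded steps for 10. ms at the default 0.1 ms step, and 25 synaptic state values because All2All connects each of the 5 presynaptic neurons to each of the 5 postsynaptic neurons. A one-line check of the connection count, for illustration only:

pre_num, post_num = 5, 5
assert pre_num * post_num == 25   # second dimension of 'syn.I', 'syn.u', 'syn.x'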
diff --git a/brainpy/_src/synplast/long_term_plasticity.py b/brainpy/_src/synplast/long_term_plasticity.py
deleted file mode 100644
index 40a96afc6..000000000
--- a/brainpy/_src/synplast/long_term_plasticity.py
+++ /dev/null
@@ -1 +0,0 @@
-# -*- coding: utf-8 -*-
diff --git a/brainpy/_src/tests/test_dynsys.py b/brainpy/_src/tests/test_dynsys.py
new file mode 100644
index 000000000..b7a2ebdab
--- /dev/null
+++ b/brainpy/_src/tests/test_dynsys.py
@@ -0,0 +1,40 @@
+
+import brainpy as bp
+
+
+def test1():
+ class A(bp.DynamicalSystem):
+ def update(self, x=None):
+ # print(tdi)
+ print(x)
+
+ A()({}, 10.)
+
+
+def test2():
+ class B(bp.DynamicalSystem):
+ def update(self, tdi, x=None):
+ print(tdi)
+ print(x)
+
+ B()({}, 10.)
+ B()(10.)
+
+
+def test3():
+ class A(bp.DynamicalSystem):
+ def update(self, x=None):
+ # print(tdi)
+ print('A:', x)
+
+ class B(A):
+ def update(self, tdi, x=None):
+ print('B:', tdi, x)
+ super().update(x)
+
+ B()(dict(), 1.)
+ B()(1.)
+
+
+
+
diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py
new file mode 100644
index 000000000..fa9a43177
--- /dev/null
+++ b/brainpy/_src/tests/test_mixin.py
@@ -0,0 +1,30 @@
+import brainpy as bp
+
+import unittest
+
+
+class TestParamDesc(unittest.TestCase):
+ def test1(self):
+ a = bp.dyn.Expon(1)
+ self.assertTrue(not isinstance(a, bp.mixin.ParamDesc[bp.dyn.Expon]))
+ self.assertTrue(not isinstance(a, bp.mixin.ParamDesc[bp.DynamicalSystem]))
+
+ def test2(self):
+ a = bp.dyn.Expon.desc(1)
+ self.assertTrue(isinstance(a, bp.mixin.ParamDesc[bp.dyn.Expon]))
+ self.assertTrue(isinstance(a, bp.mixin.ParamDesc[bp.DynamicalSystem]))
+
+
+class TestJointType(unittest.TestCase):
+ def test1(self):
+ T = bp.mixin.JointType[bp.DynamicalSystem]
+ self.assertTrue(isinstance(bp.dnn.Layer(), T))
+
+ T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc]
+ self.assertTrue(isinstance(bp.dyn.Expon(1), T))
+
+ def test2(self):
+ T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc]
+ self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDesc[T]))
+ self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDesc[T]))
+
diff --git a/brainpy/_src/train/__init__.py b/brainpy/_src/train/__init__.py
index d9a959a57..1d0bdb276 100644
--- a/brainpy/_src/train/__init__.py
+++ b/brainpy/_src/train/__init__.py
@@ -21,5 +21,4 @@
- and others.
"""
-
-
+from . import base, back_propagation, online, offline
diff --git a/brainpy/_src/transform.py b/brainpy/_src/transform.py
index 8ae39c65d..bd64f8a90 100644
--- a/brainpy/_src/transform.py
+++ b/brainpy/_src/transform.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
+
import functools
from typing import Union, Optional, Dict, Sequence
@@ -7,16 +8,16 @@
from brainpy import tools, math as bm
from brainpy._src.context import share
+from brainpy._src.dynsys import DynamicalSystem
from brainpy.check import is_float, is_integer
from brainpy.types import PyTree
-from brainpy._src.dynsys import DynamicalSystem, DynamicalSystemNS
__all__ = [
'LoopOverTime',
]
-class LoopOverTime(DynamicalSystemNS):
+class LoopOverTime(DynamicalSystem):
"""Transform a single step :py:class:`~.DynamicalSystem`
into a multiple-step forward propagation :py:class:`~.BrainPyObject`.
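LoopOverTime, per the docstring above, turns a single-step model into a multi-step forward pass over a leading time axis. A rough usage sketch follows; it assumes the class is re-exported as brainpy.LoopOverTime and is called with a time-major input array, neither of which is shown in this hunk:

import brainpy as bp
import brainpy.math as bm

cell = bp.dyn.Expon(10)           # any single-step DynamicalSystem
looper = bp.LoopOverTime(cell)    # wrap it into a loop over the time axis
xs = bm.ones((200, 10))           # 200 time steps of input for the 10 units
out = looper(xs)                  # expected to stack the 200 single-step outputs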
diff --git a/brainpy/_src/typing_copy.py b/brainpy/_src/typing_copy.py
new file mode 100644
index 000000000..8e9b25276
--- /dev/null
+++ b/brainpy/_src/typing_copy.py
@@ -0,0 +1,2273 @@
+"""
+The typing module: Support for gradual typing as defined by PEP 484.
+
+At large scale, the structure of the module is as follows:
+* Imports and exports, all public names should be explicitly added to __all__.
+* Internal helper functions: these should never be used in code outside this module.
+* _SpecialForm and its instances (special forms): Any, NoReturn, ClassVar, Union, Optional
+* Two classes whose instances can be type arguments in addition to types: ForwardRef and TypeVar
+* The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is
+ currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str],
+ etc., are instances of either of these classes.
+* The public counterpart of the generics API consists of two classes: Generic and Protocol.
+* Public helper functions: get_type_hints, overload, cast, no_type_check,
+ no_type_check_decorator.
+* Generic aliases for collections.abc ABCs and few additional protocols.
+* Special types: NewType, NamedTuple, TypedDict.
+* Wrapper submodules for re and io related types.
+"""
+
+from abc import abstractmethod, ABCMeta
+import collections
+import collections.abc
+import contextlib
+import functools
+import operator
+import re as stdlib_re # Avoid confusion with the re we export.
+import sys
+import types
+from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType, GenericAlias
+
+# Please keep __all__ alphabetized within each category.
+__all__ = [
+ # Super-special typing primitives.
+ 'Annotated',
+ 'Any',
+ 'Callable',
+ 'ClassVar',
+ 'Final',
+ 'ForwardRef',
+ 'Generic',
+ 'Literal',
+ 'Optional',
+ 'Protocol',
+ 'Tuple',
+ 'Type',
+ 'TypeVar',
+ 'Union',
+
+ # ABCs (from collections.abc).
+ 'AbstractSet', # collections.abc.Set.
+ 'ByteString',
+ 'Container',
+ 'ContextManager',
+ 'Hashable',
+ 'ItemsView',
+ 'Iterable',
+ 'Iterator',
+ 'KeysView',
+ 'Mapping',
+ 'MappingView',
+ 'MutableMapping',
+ 'MutableSequence',
+ 'MutableSet',
+ 'Sequence',
+ 'Sized',
+ 'ValuesView',
+ 'Awaitable',
+ 'AsyncIterator',
+ 'AsyncIterable',
+ 'Coroutine',
+ 'Collection',
+ 'AsyncGenerator',
+ 'AsyncContextManager',
+
+ # Structural checks, a.k.a. protocols.
+ 'Reversible',
+ 'SupportsAbs',
+ 'SupportsBytes',
+ 'SupportsComplex',
+ 'SupportsFloat',
+ 'SupportsIndex',
+ 'SupportsInt',
+ 'SupportsRound',
+
+ # Concrete collection types.
+ 'ChainMap',
+ 'Counter',
+ 'Deque',
+ 'Dict',
+ 'DefaultDict',
+ 'List',
+ 'OrderedDict',
+ 'Set',
+ 'FrozenSet',
+ 'NamedTuple', # Not really a type.
+ 'TypedDict', # Not really a type.
+ 'Generator',
+
+ # Other concrete types.
+ 'BinaryIO',
+ 'IO',
+ 'Match',
+ 'Pattern',
+ 'TextIO',
+
+ # One-off things.
+ 'AnyStr',
+ 'cast',
+ 'final',
+ 'get_args',
+ 'get_origin',
+ 'get_type_hints',
+ 'NewType',
+ 'no_type_check',
+ 'no_type_check_decorator',
+ 'NoReturn',
+ 'overload',
+ 'runtime_checkable',
+ 'Text',
+ 'TYPE_CHECKING',
+]
+
+
+# The pseudo-submodules 're' and 'io' are part of the public
+# namespace, but excluded from __all__ because they might stomp on
+# legitimate imports of those modules.
+
+
+def _type_convert(arg, module=None, *, allow_special_forms=False):
+ """For converting None to type(None), and strings to ForwardRef."""
+ if arg is None:
+ return type(None)
+ if isinstance(arg, str):
+ return ForwardRef(arg, module=module, is_class=allow_special_forms)
+ return arg
+
+
+def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False):
+ """Check that the argument is a type, and return it (internal helper).
+
+ As a special case, accept None and return type(None) instead. Also wrap strings
+ into ForwardRef instances. Consider several corner cases, for example plain
+ special forms like Union are not valid, while Union[int, str] is OK, etc.
+ The msg argument is a human-readable error message, e.g::
+
+ "Union[arg, ...]: arg should be a type."
+
+ We append the repr() of the actual value (truncated to 100 chars).
+ """
+ invalid_generic_forms = (Generic, Protocol)
+ if not allow_special_forms:
+ invalid_generic_forms += (ClassVar,)
+ if is_argument:
+ invalid_generic_forms += (Final,)
+
+ arg = _type_convert(arg, module=module, allow_special_forms=allow_special_forms)
+ if (isinstance(arg, _GenericAlias) and
+ arg.__origin__ in invalid_generic_forms):
+ raise TypeError(f"{arg} is not valid as type argument")
+ if arg in (Any, NoReturn, Final):
+ return arg
+ if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
+ raise TypeError(f"Plain {arg} is not valid as type argument")
+ if isinstance(arg, (type, TypeVar, ForwardRef)):
+ return arg
+ if not callable(arg):
+ raise TypeError(f"{msg} Got {arg!r:.100}.")
+ return arg
+
+
+def _type_repr(obj):
+ """Return the repr() of an object, special-casing types (internal helper).
+
+ If obj is a type, we return a shorter version than the default
+ type.__repr__, based on the module and qualified name, which is
+ typically enough to uniquely identify a type. For everything
+ else, we fall back on repr(obj).
+ """
+ if isinstance(obj, types.GenericAlias):
+ return repr(obj)
+ if isinstance(obj, type):
+ if obj.__module__ == 'builtins':
+ return obj.__qualname__
+ return f'{obj.__module__}.{obj.__qualname__}'
+ if obj is ...:
+ return ('...')
+ if isinstance(obj, types.FunctionType):
+ return obj.__name__
+ return repr(obj)
+
+
+def _collect_type_vars(types):
+ """Collect all type variable contained in types in order of
+ first appearance (lexicographic order). For example::
+
+ _collect_type_vars((T, List[S, T])) == (T, S)
+ """
+ tvars = []
+ for t in types:
+ if isinstance(t, TypeVar) and t not in tvars:
+ tvars.append(t)
+ if isinstance(t, (_GenericAlias, GenericAlias)):
+ tvars.extend([t for t in t.__parameters__ if t not in tvars])
+ return tuple(tvars)
+
+
+def _check_generic(cls, parameters, elen):
+ """Check correct count for parameters of a generic cls (internal helper).
+ This gives a nice error message in case of count mismatch.
+ """
+ if not elen:
+ raise TypeError(f"{cls} is not a generic class")
+ alen = len(parameters)
+ if alen != elen:
+ raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};"
+ f" actual {alen}, expected {elen}")
+
+
+def _deduplicate(params):
+ # Weed out strict duplicates, preserving the first of each occurrence.
+ all_params = set(params)
+ if len(all_params) < len(params):
+ new_params = []
+ for t in params:
+ if t in all_params:
+ new_params.append(t)
+ all_params.remove(t)
+ params = new_params
+ assert not all_params, all_params
+ return params
+
+
+def _remove_dups_flatten(parameters):
+ """An internal helper for Union creation and substitution: flatten Unions
+ among parameters, then remove duplicates.
+ """
+ # Flatten out Union[Union[...], ...].
+ params = []
+ for p in parameters:
+ if isinstance(p, _UnionGenericAlias):
+ params.extend(p.__args__)
+ elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union:
+ params.extend(p[1:])
+ else:
+ params.append(p)
+
+ return tuple(_deduplicate(params))
+
+
+def _flatten_literal_params(parameters):
+ """An internal helper for Literal creation: flatten Literals among parameters"""
+ params = []
+ for p in parameters:
+ if isinstance(p, _LiteralGenericAlias):
+ params.extend(p.__args__)
+ else:
+ params.append(p)
+ return tuple(params)
+
+
+_cleanups = []
+
+
+def _tp_cache(func=None, /, *, typed=False):
+ """Internal wrapper caching __getitem__ of generic types with a fallback to
+ original function for non-hashable arguments.
+ """
+
+ def decorator(func):
+ cached = functools.lru_cache(typed=typed)(func)
+ _cleanups.append(cached.cache_clear)
+
+ @functools.wraps(func)
+ def inner(*args, **kwds):
+ try:
+ return cached(*args, **kwds)
+ except TypeError:
+ pass # All real errors (not unhashable args) are raised below.
+ return func(*args, **kwds)
+
+ return inner
+
+ if func is not None:
+ return decorator(func)
+
+ return decorator
+
+
+def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
+ """Evaluate all forward references in the given type t.
+ For use of globalns and localns see the docstring for get_type_hints().
+ recursive_guard is used to prevent infinite recursion with a recursive
+ ForwardRef.
+ """
+ if isinstance(t, ForwardRef):
+ return t._evaluate(globalns, localns, recursive_guard)
+ if isinstance(t, (_GenericAlias, GenericAlias)):
+ ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
+ if ev_args == t.__args__:
+ return t
+ if isinstance(t, GenericAlias):
+ return GenericAlias(t.__origin__, ev_args)
+ else:
+ return t.copy_with(ev_args)
+ return t
+
+
+class _Final:
+ """Mixin to prohibit subclassing"""
+
+ __slots__ = ('__weakref__',)
+
+ def __init_subclass__(self, /, *args, **kwds):
+ if '_root' not in kwds:
+ raise TypeError("Cannot subclass special typing classes")
+
+
+class _Immutable:
+ """Mixin to indicate that object should not be copied."""
+ __slots__ = ()
+
+ def __copy__(self):
+ return self
+
+ def __deepcopy__(self, memo):
+ return self
+
+
+# Internal indicator of special typing constructs.
+# See __doc__ instance attribute for specific docs.
+class _SpecialForm(_Final, _root=True):
+ __slots__ = ('_name', '__doc__', '_getitem')
+
+ def __init__(self, getitem):
+ self._getitem = getitem
+ self._name = getitem.__name__
+ self.__doc__ = getitem.__doc__
+
+ def __mro_entries__(self, bases):
+ raise TypeError(f"Cannot subclass {self!r}")
+
+ def __repr__(self):
+ return 'typing.' + self._name
+
+ def __reduce__(self):
+ return self._name
+
+ def __call__(self, *args, **kwds):
+ raise TypeError(f"Cannot instantiate {self!r}")
+
+ def __instancecheck__(self, obj):
+ raise TypeError(f"{self} cannot be used with isinstance()")
+
+ def __subclasscheck__(self, cls):
+ raise TypeError(f"{self} cannot be used with issubclass()")
+
+ @_tp_cache
+ def __getitem__(self, parameters):
+ return self._getitem(self, parameters)
+
+
+class _LiteralSpecialForm(_SpecialForm, _root=True):
+ def __getitem__(self, parameters):
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+ return self._getitem(self, *parameters)
+
+
+@_SpecialForm
+def Any(self, parameters):
+ """Special type indicating an unconstrained type.
+
+ - Any is compatible with every type.
+ - Any assumed to have all methods.
+ - All values assumed to be instances of Any.
+
+ Note that all the above statements are true from the point of view of
+ static type checkers. At runtime, Any should not be used with instance
+ or class checks.
+ """
+ raise TypeError(f"{self} is not subscriptable")
+
+
+@_SpecialForm
+def NoReturn(self, parameters):
+ """Special type indicating functions that never return.
+ Example::
+
+ from typing import NoReturn
+
+ def stop() -> NoReturn:
+ raise Exception('no way')
+
+ This type is invalid in other positions, e.g., ``List[NoReturn]``
+ will fail in static type checkers.
+ """
+ raise TypeError(f"{self} is not subscriptable")
+
+
+@_SpecialForm
+def ClassVar(self, parameters):
+ """Special type construct to mark class variables.
+
+ An annotation wrapped in ClassVar indicates that a given
+ attribute is intended to be used as a class variable and
+ should not be set on instances of that class. Usage::
+
+ class Starship:
+ stats: ClassVar[Dict[str, int]] = {} # class variable
+ damage: int = 10 # instance variable
+
+ ClassVar accepts only types and cannot be further subscribed.
+
+ Note that ClassVar is not a class itself, and should not
+ be used with isinstance() or issubclass().
+ """
+ item = _type_check(parameters, f'{self} accepts only single type.')
+ return _GenericAlias(self, (item,))
+
+
+@_SpecialForm
+def Final(self, parameters):
+ """Special typing construct to indicate final names to type checkers.
+
+ A final name cannot be re-assigned or overridden in a subclass.
+ For example:
+
+ MAX_SIZE: Final = 9000
+ MAX_SIZE += 1 # Error reported by type checker
+
+ class Connection:
+ TIMEOUT: Final[int] = 10
+
+ class FastConnector(Connection):
+ TIMEOUT = 1 # Error reported by type checker
+
+ There is no runtime checking of these properties.
+ """
+ item = _type_check(parameters, f'{self} accepts only single type.')
+ return _GenericAlias(self, (item,))
+
+
+@_SpecialForm
+def Union(self, parameters):
+ """Union type; Union[X, Y] means either X or Y.
+
+ To define a union, use e.g. Union[int, str]. Details:
+ - The arguments must be types and there must be at least one.
+ - None as an argument is a special case and is replaced by
+ type(None).
+ - Unions of unions are flattened, e.g.::
+
+ Union[Union[int, str], float] == Union[int, str, float]
+
+ - Unions of a single argument vanish, e.g.::
+
+ Union[int] == int # The constructor actually returns int
+
+ - Redundant arguments are skipped, e.g.::
+
+ Union[int, str, int] == Union[int, str]
+
+ - When comparing unions, the argument order is ignored, e.g.::
+
+ Union[int, str] == Union[str, int]
+
+ - You cannot subclass or instantiate a union.
+ - You can use Optional[X] as a shorthand for Union[X, None].
+ """
+ if parameters == ():
+ raise TypeError("Cannot take a Union of no types.")
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+ msg = "Union[arg, ...]: each arg must be a type."
+ parameters = tuple(_type_check(p, msg) for p in parameters)
+ parameters = _remove_dups_flatten(parameters)
+ if len(parameters) == 1:
+ return parameters[0]
+ return _UnionGenericAlias(self, parameters)
+
+
+@_SpecialForm
+def Optional(self, parameters):
+ """Optional type.
+
+ Optional[X] is equivalent to Union[X, None].
+ """
+ arg = _type_check(parameters, f"{self} requires a single type.")
+ return Union[arg, type(None)]
+
+
+@_LiteralSpecialForm
+@_tp_cache(typed=True)
+def Literal(self, *parameters):
+ """Special typing form to define literal types (a.k.a. value types).
+
+ This form can be used to indicate to type checkers that the corresponding
+ variable or function parameter has a value equivalent to the provided
+ literal (or one of several literals):
+
+ def validate_simple(data: Any) -> Literal[True]: # always returns True
+ ...
+
+ MODE = Literal['r', 'rb', 'w', 'wb']
+ def open_helper(file: str, mode: MODE) -> str:
+ ...
+
+ open_helper('/some/path', 'r') # Passes type check
+ open_helper('/other/path', 'typo') # Error in type checker
+
+ Literal[...] cannot be subclassed. At runtime, an arbitrary value
+ is allowed as type argument to Literal[...], but type checkers may
+ impose restrictions.
+ """
+ # There is no '_type_check' call because arguments to Literal[...] are
+ # values, not types.
+ parameters = _flatten_literal_params(parameters)
+
+ try:
+ parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
+ except TypeError: # unhashable parameters
+ pass
+
+ return _LiteralGenericAlias(self, parameters)
+
+
+class ForwardRef(_Final, _root=True):
+ """Internal wrapper to hold a forward reference."""
+
+ __slots__ = ('__forward_arg__', '__forward_code__',
+ '__forward_evaluated__', '__forward_value__',
+ '__forward_is_argument__', '__forward_is_class__',
+ '__forward_module__')
+
+ def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
+ if not isinstance(arg, str):
+ raise TypeError(f"Forward reference must be a string -- got {arg!r}")
+ try:
+ code = compile(arg, '', 'eval')
+ except SyntaxError:
+ raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
+ self.__forward_arg__ = arg
+ self.__forward_code__ = code
+ self.__forward_evaluated__ = False
+ self.__forward_value__ = None
+ self.__forward_is_argument__ = is_argument
+ self.__forward_is_class__ = is_class
+ self.__forward_module__ = module
+
+ def _evaluate(self, globalns, localns, recursive_guard):
+ if self.__forward_arg__ in recursive_guard:
+ return self
+ if not self.__forward_evaluated__ or localns is not globalns:
+ if globalns is None and localns is None:
+ globalns = localns = {}
+ elif globalns is None:
+ globalns = localns
+ elif localns is None:
+ localns = globalns
+ if self.__forward_module__ is not None:
+ globalns = getattr(
+ sys.modules.get(self.__forward_module__, None), '__dict__', globalns
+ )
+ type_ = _type_check(
+ eval(self.__forward_code__, globalns, localns),
+ "Forward references must evaluate to types.",
+ is_argument=self.__forward_is_argument__,
+ allow_special_forms=self.__forward_is_class__,
+ )
+ self.__forward_value__ = _eval_type(
+ type_, globalns, localns, recursive_guard | {self.__forward_arg__}
+ )
+ self.__forward_evaluated__ = True
+ return self.__forward_value__
+
+ def __eq__(self, other):
+ if not isinstance(other, ForwardRef):
+ return NotImplemented
+ if self.__forward_evaluated__ and other.__forward_evaluated__:
+ return (self.__forward_arg__ == other.__forward_arg__ and
+ self.__forward_value__ == other.__forward_value__)
+ return (self.__forward_arg__ == other.__forward_arg__ and
+ self.__forward_module__ == other.__forward_module__)
+
+ def __hash__(self):
+ return hash((self.__forward_arg__, self.__forward_module__))
+
+ def __repr__(self):
+ return f'ForwardRef({self.__forward_arg__!r})'
+
+
+class TypeVar(_Final, _Immutable, _root=True):
+ """Type variable.
+
+ Usage::
+
+ T = TypeVar('T') # Can be anything
+ A = TypeVar('A', str, bytes) # Must be str or bytes
+
+ Type variables exist primarily for the benefit of static type
+ checkers. They serve as the parameters for generic types as well
+ as for generic function definitions. See class Generic for more
+ information on generic types. Generic functions work as follows:
+
+ def repeat(x: T, n: int) -> List[T]:
+ '''Return a list containing n references to x.'''
+ return [x]*n
+
+ def longest(x: A, y: A) -> A:
+ '''Return the longest of two strings.'''
+ return x if len(x) >= len(y) else y
+
+ The latter example's signature is essentially the overloading
+ of (str, str) -> str and (bytes, bytes) -> bytes. Also note
+ that if the arguments are instances of some subclass of str,
+ the return type is still plain str.
+
+ At runtime, isinstance(x, T) and issubclass(C, T) will raise TypeError.
+
+ Type variables defined with covariant=True or contravariant=True
+ can be used to declare covariant or contravariant generic types.
+ See PEP 484 for more details. By default generic types are invariant
+ in all type variables.
+
+ Type variables can be introspected. e.g.:
+
+ T.__name__ == 'T'
+ T.__constraints__ == ()
+ T.__covariant__ == False
+      T.__contravariant__ == False
+ A.__constraints__ == (str, bytes)
+
+ Note that only type variables defined in global scope can be pickled.
+ """
+
+ __slots__ = ('__name__', '__bound__', '__constraints__',
+ '__covariant__', '__contravariant__', '__dict__')
+
+ def __init__(self, name, *constraints, bound=None,
+ covariant=False, contravariant=False):
+ self.__name__ = name
+ if covariant and contravariant:
+ raise ValueError("Bivariant types are not supported.")
+ self.__covariant__ = bool(covariant)
+ self.__contravariant__ = bool(contravariant)
+ if constraints and bound is not None:
+ raise TypeError("Constraints cannot be combined with bound=...")
+ if constraints and len(constraints) == 1:
+ raise TypeError("A single constraint is not allowed")
+ msg = "TypeVar(name, constraint, ...): constraints must be types."
+ self.__constraints__ = tuple(_type_check(t, msg) for t in constraints)
+ if bound:
+ self.__bound__ = _type_check(bound, "Bound must be a type.")
+ else:
+ self.__bound__ = None
+ try:
+ def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') # for pickling
+ except (AttributeError, ValueError):
+ def_mod = None
+ if def_mod != 'typing':
+ self.__module__ = def_mod
+
+ def __repr__(self):
+ if self.__covariant__:
+ prefix = '+'
+ elif self.__contravariant__:
+ prefix = '-'
+ else:
+ prefix = '~'
+ return prefix + self.__name__
+
+ def __reduce__(self):
+ return self.__name__
+
+
+def _is_dunder(attr):
+ return attr.startswith('__') and attr.endswith('__')
+
+
+class _BaseGenericAlias(_Final, _root=True):
+ """The central part of internal API.
+
+ This represents a generic version of type 'origin' with type arguments 'params'.
+ There are two kind of these aliases: user defined and special. The special ones
+ are wrappers around builtin collections and ABCs in collections.abc. These must
+ have 'name' always set. If 'inst' is False, then the alias can't be instantiated,
+ this is used by e.g. typing.List and typing.Dict.
+ """
+
+ def __init__(self, origin, *, inst=True, name=None):
+ self._inst = inst
+ self._name = name
+ self.__origin__ = origin
+ self.__slots__ = None # This is not documented.
+
+ def __call__(self, *args, **kwargs):
+ if not self._inst:
+ raise TypeError(f"Type {self._name} cannot be instantiated; "
+ f"use {self.__origin__.__name__}() instead")
+ result = self.__origin__(*args, **kwargs)
+ try:
+ result.__orig_class__ = self
+ except AttributeError:
+ pass
+ return result
+
+ def __mro_entries__(self, bases):
+ res = []
+ if self.__origin__ not in bases:
+ res.append(self.__origin__)
+ i = bases.index(self)
+ for b in bases[i + 1:]:
+ if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic):
+ break
+ else:
+ res.append(Generic)
+ return tuple(res)
+
+ def __getattr__(self, attr):
+ # We are careful for copy and pickle.
+ # Also for simplicity we don't relay any dunder names
+ if '__origin__' in self.__dict__ and not _is_dunder(attr):
+ return getattr(self.__origin__, attr)
+ raise AttributeError(attr)
+
+ def __setattr__(self, attr, val):
+ if _is_dunder(attr) or attr in ('_name', '_inst', '_nparams'):
+ super().__setattr__(attr, val)
+ else:
+ setattr(self.__origin__, attr, val)
+
+ def __instancecheck__(self, obj):
+ return self.__subclasscheck__(type(obj))
+
+ def __subclasscheck__(self, cls):
+ raise TypeError("Subscripted generics cannot be used with"
+ " class and instance checks")
+
+
+# Special typing constructs Union, Optional, Generic, Callable and Tuple
+# use three special attributes for internal bookkeeping of generic types:
+# * __parameters__ is a tuple of unique free type parameters of a generic
+# type, for example, Dict[T, T].__parameters__ == (T,);
+# * __origin__ keeps a reference to a type that was subscripted,
+# e.g., Union[T, int].__origin__ == Union, or the non-generic version of
+# the type.
+# * __args__ is a tuple of all arguments used in subscripting,
+# e.g., Dict[T, int].__args__ == (T, int).
+
+
+class _GenericAlias(_BaseGenericAlias, _root=True):
+ def __init__(self, origin, params, *, inst=True, name=None):
+ super().__init__(origin, inst=inst, name=name)
+ if not isinstance(params, tuple):
+ params = (params,)
+ self.__args__ = tuple(... if a is _TypingEllipsis else
+ () if a is _TypingEmpty else
+ a for a in params)
+ self.__parameters__ = _collect_type_vars(params)
+ if not name:
+ self.__module__ = origin.__module__
+
+ def __eq__(self, other):
+ if not isinstance(other, _GenericAlias):
+ return NotImplemented
+ return (self.__origin__ == other.__origin__
+ and self.__args__ == other.__args__)
+
+ def __hash__(self):
+ return hash((self.__origin__, self.__args__))
+
+ @_tp_cache
+ def __getitem__(self, params):
+ if self.__origin__ in (Generic, Protocol):
+ # Can't subscript Generic[...] or Protocol[...].
+ raise TypeError(f"Cannot subscript already-subscripted {self}")
+ if not isinstance(params, tuple):
+ params = (params,)
+ msg = "Parameters to generic types must be types."
+ params = tuple(_type_check(p, msg) for p in params)
+ _check_generic(self, params, len(self.__parameters__))
+
+ subst = dict(zip(self.__parameters__, params))
+ new_args = []
+ for arg in self.__args__:
+ if isinstance(arg, TypeVar):
+ arg = subst[arg]
+ elif isinstance(arg, (_GenericAlias, GenericAlias)):
+ subparams = arg.__parameters__
+ if subparams:
+ subargs = tuple(subst[x] for x in subparams)
+ arg = arg[subargs]
+ new_args.append(arg)
+ return self.copy_with(tuple(new_args))
+
+ def copy_with(self, params):
+ return self.__class__(self.__origin__, params, name=self._name, inst=self._inst)
+
+ def __repr__(self):
+ if self._name:
+ name = 'typing.' + self._name
+ else:
+ name = _type_repr(self.__origin__)
+ args = ", ".join([_type_repr(a) for a in self.__args__])
+ return f'{name}[{args}]'
+
+ def __reduce__(self):
+ if self._name:
+ origin = globals()[self._name]
+ else:
+ origin = self.__origin__
+ args = tuple(self.__args__)
+ if len(args) == 1 and not isinstance(args[0], tuple):
+ args, = args
+ return operator.getitem, (origin, args)
+
+ def __mro_entries__(self, bases):
+ if self._name: # generic version of an ABC or built-in class
+ return super().__mro_entries__(bases)
+ if self.__origin__ is Generic:
+ if Protocol in bases:
+ return ()
+ i = bases.index(self)
+ for b in bases[i + 1:]:
+ if isinstance(b, _BaseGenericAlias) and b is not self:
+ return ()
+ return (self.__origin__,)
+
+
+# _nparams is the number of accepted parameters, e.g. 0 for Hashable,
+# 1 for List and 2 for Dict. It may be -1 if variable number of
+# parameters are accepted (needs custom __getitem__).
+
+class _SpecialGenericAlias(_BaseGenericAlias, _root=True):
+ def __init__(self, origin, nparams, *, inst=True, name=None):
+ if name is None:
+ name = origin.__name__
+ super().__init__(origin, inst=inst, name=name)
+ self._nparams = nparams
+ if origin.__module__ == 'builtins':
+ self.__doc__ = f'A generic version of {origin.__qualname__}.'
+ else:
+ self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}.'
+
+ @_tp_cache
+ def __getitem__(self, params):
+ if not isinstance(params, tuple):
+ params = (params,)
+ msg = "Parameters to generic types must be types."
+ params = tuple(_type_check(p, msg) for p in params)
+ _check_generic(self, params, self._nparams)
+ return self.copy_with(params)
+
+ def copy_with(self, params):
+ return _GenericAlias(self.__origin__, params,
+ name=self._name, inst=self._inst)
+
+ def __repr__(self):
+ return 'typing.' + self._name
+
+ def __subclasscheck__(self, cls):
+ if isinstance(cls, _SpecialGenericAlias):
+ return issubclass(cls.__origin__, self.__origin__)
+ if not isinstance(cls, _GenericAlias):
+ return issubclass(cls, self.__origin__)
+ return super().__subclasscheck__(cls)
+
+ def __reduce__(self):
+ return self._name
+
+
+class _CallableGenericAlias(_GenericAlias, _root=True):
+ def __repr__(self):
+ assert self._name == 'Callable'
+ if len(self.__args__) == 2 and self.__args__[0] is Ellipsis:
+ return super().__repr__()
+ return (f'typing.Callable'
+ f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
+ f'{_type_repr(self.__args__[-1])}]')
+
+ def __reduce__(self):
+ args = self.__args__
+ if not (len(args) == 2 and args[0] is ...):
+ args = list(args[:-1]), args[-1]
+ return operator.getitem, (Callable, args)
+
+
+class _CallableType(_SpecialGenericAlias, _root=True):
+ def copy_with(self, params):
+ return _CallableGenericAlias(self.__origin__, params,
+ name=self._name, inst=self._inst)
+
+ def __getitem__(self, params):
+ if not isinstance(params, tuple) or len(params) != 2:
+ raise TypeError("Callable must be used as "
+ "Callable[[arg, ...], result].")
+ args, result = params
+ # This relaxes what args can be on purpose to allow things like
+ # PEP 612 ParamSpec. Responsibility for whether a user is using
+ # Callable[...] properly is deferred to static type checkers.
+ if isinstance(args, list):
+ params = (tuple(args), result)
+ else:
+ params = (args, result)
+ return self.__getitem_inner__(params)
+
+ @_tp_cache
+ def __getitem_inner__(self, params):
+ args, result = params
+ msg = "Callable[args, result]: result must be a type."
+ result = _type_check(result, msg)
+ if args is Ellipsis:
+ return self.copy_with((_TypingEllipsis, result))
+ if not isinstance(args, tuple):
+ args = (args,)
+ args = tuple(_type_convert(arg) for arg in args)
+ params = args + (result,)
+ return self.copy_with(params)
+
+
+class _TupleType(_SpecialGenericAlias, _root=True):
+ @_tp_cache
+ def __getitem__(self, params):
+ if params == ():
+ return self.copy_with((_TypingEmpty,))
+ if not isinstance(params, tuple):
+ params = (params,)
+ if len(params) == 2 and params[1] is ...:
+ msg = "Tuple[t, ...]: t must be a type."
+ p = _type_check(params[0], msg)
+ return self.copy_with((p, _TypingEllipsis))
+ msg = "Tuple[t0, t1, ...]: each t must be a type."
+ params = tuple(_type_check(p, msg) for p in params)
+ return self.copy_with(params)
+
+
+class _UnionGenericAlias(_GenericAlias, _root=True):
+ def copy_with(self, params):
+ return Union[params]
+
+ def __eq__(self, other):
+ if not isinstance(other, _UnionGenericAlias):
+ return NotImplemented
+ return set(self.__args__) == set(other.__args__)
+
+ def __hash__(self):
+ return hash(frozenset(self.__args__))
+
+ def __repr__(self):
+ args = self.__args__
+ if len(args) == 2:
+ if args[0] is type(None):
+ return f'typing.Optional[{_type_repr(args[1])}]'
+ elif args[1] is type(None):
+ return f'typing.Optional[{_type_repr(args[0])}]'
+ return super().__repr__()
+
+
+def _value_and_type_iter(parameters):
+ return ((p, type(p)) for p in parameters)
+
+
+class _LiteralGenericAlias(_GenericAlias, _root=True):
+
+ def __eq__(self, other):
+ if not isinstance(other, _LiteralGenericAlias):
+ return NotImplemented
+
+ return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
+
+ def __hash__(self):
+ return hash(frozenset(_value_and_type_iter(self.__args__)))
+
+
+class Generic:
+ """Abstract base class for generic types.
+
+ A generic type is typically declared by inheriting from
+ this class parameterized with one or more type variables.
+ For example, a generic mapping type might be defined as::
+
+ class Mapping(Generic[KT, VT]):
+ def __getitem__(self, key: KT) -> VT:
+ ...
+ # Etc.
+
+ This class can then be used as follows::
+
+ def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
+ try:
+ return mapping[key]
+ except KeyError:
+ return default
+ """
+ __slots__ = ()
+ _is_protocol = False
+
+ @_tp_cache
+ def __class_getitem__(cls, params):
+ if not isinstance(params, tuple):
+ params = (params,)
+ if not params and cls is not Tuple:
+ raise TypeError(
+ f"Parameter list to {cls.__qualname__}[...] cannot be empty")
+ msg = "Parameters to generic types must be types."
+ params = tuple(_type_check(p, msg) for p in params)
+ if cls in (Generic, Protocol):
+ # Generic and Protocol can only be subscripted with unique type variables.
+ if not all(isinstance(p, TypeVar) for p in params):
+ raise TypeError(
+ f"Parameters to {cls.__name__}[...] must all be type variables")
+ if len(set(params)) != len(params):
+ raise TypeError(
+ f"Parameters to {cls.__name__}[...] must all be unique")
+ else:
+ # Subscripting a regular Generic subclass.
+ _check_generic(cls, params, len(cls.__parameters__))
+ return _GenericAlias(cls, params)
+
+ def __init_subclass__(cls, *args, **kwargs):
+ super().__init_subclass__(*args, **kwargs)
+ tvars = []
+ if '__orig_bases__' in cls.__dict__:
+ error = Generic in cls.__orig_bases__
+ else:
+ error = Generic in cls.__bases__ and cls.__name__ != 'Protocol'
+ if error:
+ raise TypeError("Cannot inherit from plain Generic")
+ if '__orig_bases__' in cls.__dict__:
+ tvars = _collect_type_vars(cls.__orig_bases__)
+ # Look for Generic[T1, ..., Tn].
+ # If found, tvars must be a subset of it.
+ # If not found, tvars is it.
+ # Also check for and reject plain Generic,
+ # and reject multiple Generic[...].
+ gvars = None
+ for base in cls.__orig_bases__:
+ if (isinstance(base, _GenericAlias) and
+ base.__origin__ is Generic):
+ if gvars is not None:
+ raise TypeError(
+ "Cannot inherit from Generic[...] multiple types.")
+ gvars = base.__parameters__
+ if gvars is not None:
+ tvarset = set(tvars)
+ gvarset = set(gvars)
+ if not tvarset <= gvarset:
+ s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
+ s_args = ', '.join(str(g) for g in gvars)
+ raise TypeError(f"Some type variables ({s_vars}) are"
+ f" not listed in Generic[{s_args}]")
+ tvars = gvars
+ cls.__parameters__ = tuple(tvars)
+
+
+class _TypingEmpty:
+ """Internal placeholder for () or []. Used by TupleMeta and CallableMeta
+ to allow empty list/tuple in specific places, without allowing them
+ to sneak in where prohibited.
+ """
+
+
+class _TypingEllipsis:
+ """Internal placeholder for ... (ellipsis)."""
+
+
+_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
+ '_is_protocol', '_is_runtime_protocol']
+
+_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
+ '__init__', '__module__', '__new__', '__slots__',
+ '__subclasshook__', '__weakref__', '__class_getitem__']
+
+# These special attributes will be not collected as protocol members.
+EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']
+
+
+def _get_protocol_attrs(cls):
+ """Collect protocol members from a protocol class objects.
+
+ This includes names actually defined in the class dictionary, as well
+ as names that appear in annotations. Special names (above) are skipped.
+ """
+ attrs = set()
+ for base in cls.__mro__[:-1]: # without object
+ if base.__name__ in ('Protocol', 'Generic'):
+ continue
+ annotations = getattr(base, '__annotations__', {})
+ for attr in list(base.__dict__.keys()) + list(annotations.keys()):
+ if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
+ attrs.add(attr)
+ return attrs
+
+
+def _is_callable_members_only(cls):
+ # PEP 544 prohibits using issubclass() with protocols that have non-method members.
+ return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
+
+
+def _no_init_or_replace_init(self, *args, **kwargs):
+ cls = type(self)
+
+ if cls._is_protocol:
+ raise TypeError('Protocols cannot be instantiated')
+
+ # Already using a custom `__init__`. No need to calculate correct
+ # `__init__` to call. This can lead to RecursionError. See bpo-45121.
+ if cls.__init__ is not _no_init_or_replace_init:
+ return
+
+ # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
+ # The first instantiation of the subclass will call `_no_init_or_replace_init` which
+ # searches for a proper new `__init__` in the MRO. The new `__init__`
+ # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
+ # instantiation of the protocol subclass will thus use the new
+ # `__init__` and no longer call `_no_init_or_replace_init`.
+ for base in cls.__mro__:
+ init = base.__dict__.get('__init__', _no_init_or_replace_init)
+ if init is not _no_init_or_replace_init:
+ cls.__init__ = init
+ break
+ else:
+ # should not happen
+ cls.__init__ = object.__init__
+
+ cls.__init__(self, *args, **kwargs)
+
+
+def _allow_reckless_class_cheks():
+ """Allow instance and class checks for special stdlib modules.
+
+ The abc and functools modules indiscriminately call isinstance() and
+ issubclass() on the whole MRO of a user class, which may contain protocols.
+ """
+ try:
+ return sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']
+ except (AttributeError, ValueError): # For platforms without _getframe().
+ return True
+
+
+_PROTO_WHITELIST = {
+ 'collections.abc': [
+ 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
+ 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
+ ],
+ 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'],
+}
+
+
+class _ProtocolMeta(ABCMeta):
+ # This metaclass is really unfortunate and exists only because of
+ # the lack of __instancehook__.
+ def __instancecheck__(cls, instance):
+ # We need this method for situations where attributes are
+ # assigned in __init__.
+ if ((not getattr(cls, '_is_protocol', False) or
+ _is_callable_members_only(cls)) and
+ issubclass(instance.__class__, cls)):
+ return True
+ if cls._is_protocol:
+ if all(hasattr(instance, attr) and
+ # All *methods* can be blocked by setting them to None.
+ (not callable(getattr(cls, attr, None)) or
+ getattr(instance, attr) is not None)
+ for attr in _get_protocol_attrs(cls)):
+ return True
+ return super().__instancecheck__(instance)
+
+
+class Protocol(Generic, metaclass=_ProtocolMeta):
+ """Base class for protocol classes.
+
+ Protocol classes are defined as::
+
+ class Proto(Protocol):
+ def meth(self) -> int:
+ ...
+
+ Such classes are primarily used with static type checkers that recognize
+ structural subtyping (static duck-typing), for example::
+
+ class C:
+ def meth(self) -> int:
+ return 0
+
+ def func(x: Proto) -> int:
+ return x.meth()
+
+ func(C()) # Passes static type check
+
+ See PEP 544 for details. Protocol classes decorated with
+ @typing.runtime_checkable act as simple-minded runtime protocols that check
+ only the presence of given attributes, ignoring their type signatures.
+ Protocol classes can be generic, they are defined as::
+
+ class GenProto(Protocol[T]):
+ def meth(self) -> T:
+ ...
+ """
+ __slots__ = ()
+ _is_protocol = True
+ _is_runtime_protocol = False
+
+ def __init_subclass__(cls, *args, **kwargs):
+ super().__init_subclass__(*args, **kwargs)
+
+ # Determine if this is a protocol or a concrete subclass.
+ if not cls.__dict__.get('_is_protocol', False):
+ cls._is_protocol = any(b is Protocol for b in cls.__bases__)
+
+ # Set (or override) the protocol subclass hook.
+ def _proto_hook(other):
+ if not cls.__dict__.get('_is_protocol', False):
+ return NotImplemented
+
+ # First, perform various sanity checks.
+ if not getattr(cls, '_is_runtime_protocol', False):
+ if _allow_reckless_class_cheks():
+ return NotImplemented
+ raise TypeError("Instance and class checks can only be used with"
+ " @runtime_checkable protocols")
+ if not _is_callable_members_only(cls):
+ if _allow_reckless_class_cheks():
+ return NotImplemented
+ raise TypeError("Protocols with non-method members"
+ " don't support issubclass()")
+ if not isinstance(other, type):
+ # Same error message as for issubclass(1, int).
+ raise TypeError('issubclass() arg 1 must be a class')
+
+ # Second, perform the actual structural compatibility check.
+ for attr in _get_protocol_attrs(cls):
+ for base in other.__mro__:
+ # Check if the members appears in the class dictionary...
+ if attr in base.__dict__:
+ if base.__dict__[attr] is None:
+ return NotImplemented
+ break
+
+ # ...or in annotations, if it is a sub-protocol.
+ annotations = getattr(base, '__annotations__', {})
+ if (isinstance(annotations, collections.abc.Mapping) and
+ attr in annotations and
+ issubclass(other, Generic) and other._is_protocol):
+ break
+ else:
+ return NotImplemented
+ return True
+
+ if '__subclasshook__' not in cls.__dict__:
+ cls.__subclasshook__ = _proto_hook
+
+ # We have nothing more to do for non-protocols...
+ if not cls._is_protocol:
+ return
+
+ # ... otherwise check consistency of bases, and prohibit instantiation.
+ for base in cls.__bases__:
+ if not (base in (object, Generic) or
+ base.__module__ in _PROTO_WHITELIST and
+ base.__name__ in _PROTO_WHITELIST[base.__module__] or
+ issubclass(base, Generic) and base._is_protocol):
+ raise TypeError('Protocols can only inherit from other'
+ ' protocols, got %r' % base)
+ cls.__init__ = _no_init_or_replace_init
+
+
+class _AnnotatedAlias(_GenericAlias, _root=True):
+ """Runtime representation of an annotated type.
+
+ At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
+    with extra annotations. The alias behaves like a normal typing alias:
+    instantiating it is the same as instantiating the underlying type, and
+    binding it to types is also the same.
+ """
+
+ def __init__(self, origin, metadata):
+ if isinstance(origin, _AnnotatedAlias):
+ metadata = origin.__metadata__ + metadata
+ origin = origin.__origin__
+ super().__init__(origin, origin)
+ self.__metadata__ = metadata
+
+ def copy_with(self, params):
+ assert len(params) == 1
+ new_type = params[0]
+ return _AnnotatedAlias(new_type, self.__metadata__)
+
+ def __repr__(self):
+ return "typing.Annotated[{}, {}]".format(
+ _type_repr(self.__origin__),
+ ", ".join(repr(a) for a in self.__metadata__)
+ )
+
+ def __reduce__(self):
+ return operator.getitem, (
+ Annotated, (self.__origin__,) + self.__metadata__
+ )
+
+ def __eq__(self, other):
+ if not isinstance(other, _AnnotatedAlias):
+ return NotImplemented
+ return (self.__origin__ == other.__origin__
+ and self.__metadata__ == other.__metadata__)
+
+ def __hash__(self):
+ return hash((self.__origin__, self.__metadata__))
+
+
+class Annotated:
+ """Add context specific metadata to a type.
+
+ Example: Annotated[int, runtime_check.Unsigned] indicates to the
+ hypothetical runtime_check module that this type is an unsigned int.
+ Every other consumer of this type can ignore this metadata and treat
+ this type as int.
+
+ The first argument to Annotated must be a valid type.
+
+ Details:
+
+    - It's an error to call `Annotated` with fewer than two arguments.
+ - Nested Annotated are flattened::
+
+ Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]
+
+ - Instantiating an annotated type is equivalent to instantiating the
+ underlying type::
+
+ Annotated[C, Ann1](5) == C(5)
+
+ - Annotated can be used as a generic type alias::
+
+ Optimized = Annotated[T, runtime.Optimize()]
+ Optimized[int] == Annotated[int, runtime.Optimize()]
+
+ OptimizedList = Annotated[List[T], runtime.Optimize()]
+ OptimizedList[int] == Annotated[List[int], runtime.Optimize()]
+ """
+
+ __slots__ = ()
+
+ def __new__(cls, *args, **kwargs):
+ raise TypeError("Type Annotated cannot be instantiated.")
+
+ @_tp_cache
+ def __class_getitem__(cls, params):
+ if not isinstance(params, tuple) or len(params) < 2:
+ raise TypeError("Annotated[...] should be used "
+ "with at least two arguments (a type and an "
+ "annotation).")
+ msg = "Annotated[t, ...]: t must be a type."
+ origin = _type_check(params[0], msg, allow_special_forms=True)
+ metadata = tuple(params[1:])
+ return _AnnotatedAlias(origin, metadata)
+
+ def __init_subclass__(cls, *args, **kwargs):
+ raise TypeError(
+ "Cannot subclass {}.Annotated".format(cls.__module__)
+ )
+
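# --- Editor's illustration (not part of this patch) ---------------------------
# Annotated attaches arbitrary metadata to a type; nested uses are flattened,
# and instantiating the alias instantiates the underlying type.
from typing import Annotated

UserId = Annotated[int, "primary key"]
assert Annotated[UserId, "indexed"] == Annotated[int, "primary key", "indexed"]
assert UserId(5) == 5                 # behaves like plain int(5) at runtime
# -------------------------------------------------------------------------------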
+
+def runtime_checkable(cls):
+ """Mark a protocol class as a runtime protocol.
+
+    Such a protocol can be used with isinstance() and issubclass().
+    Raise TypeError if applied to a non-protocol class.
+    This allows a simple-minded structural check, very similar to the
+    one-trick ponies in collections.abc such as Iterable.
+ For example::
+
+ @runtime_checkable
+ class Closable(Protocol):
+ def close(self): ...
+
+ assert isinstance(open('/some/file'), Closable)
+
+ Warning: this will check only the presence of the required methods,
+ not their type signatures!
+ """
+ if not issubclass(cls, Generic) or not cls._is_protocol:
+ raise TypeError('@runtime_checkable can be only applied to protocol classes,'
+ ' got %r' % cls)
+ cls._is_runtime_protocol = True
+ return cls
+
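# --- Editor's illustration (not part of this patch) ---------------------------
# With @runtime_checkable, isinstance() only verifies that the named methods
# exist; signatures and return types are never inspected.
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsClose(Protocol):
    def close(self) -> None: ...

class Socket:
    def close(self, timeout: float = 0.0) -> str:   # different signature
        return "closed"

assert isinstance(Socket(), SupportsClose)           # True: 'close' merely exists
# -------------------------------------------------------------------------------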
+
+def cast(typ, val):
+ """Cast a value to a type.
+
+ This returns the value unchanged. To the type checker this
+ signals that the return value has the designated type, but at
+ runtime we intentionally don't check anything (we want this
+ to be as fast as possible).
+ """
+ return val
+
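# --- Editor's illustration (not part of this patch) ---------------------------
# cast() is purely an annotation aid: the value comes back untouched and no
# runtime validation or conversion happens.
from typing import List, cast

numbers = cast(List[int], ["1", "2"])
assert numbers == ["1", "2"]          # still strings; only the checker is told otherwise
# -------------------------------------------------------------------------------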
+
+def _get_defaults(func):
+ """Internal helper to extract the default arguments, by name."""
+ try:
+ code = func.__code__
+ except AttributeError:
+ # Some built-in functions don't have __code__, __defaults__, etc.
+ return {}
+ pos_count = code.co_argcount
+ arg_names = code.co_varnames
+ arg_names = arg_names[:pos_count]
+ defaults = func.__defaults__ or ()
+ kwdefaults = func.__kwdefaults__
+ res = dict(kwdefaults) if kwdefaults else {}
+ pos_offset = pos_count - len(defaults)
+ for name, value in zip(arg_names[pos_offset:], defaults):
+ assert name not in res
+ res[name] = value
+ return res
+
+
+_allowed_types = (types.FunctionType, types.BuiltinFunctionType,
+ types.MethodType, types.ModuleType,
+ WrapperDescriptorType, MethodWrapperType, MethodDescriptorType)
+
+
+def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
+ """Return type hints for an object.
+
+ This is often the same as obj.__annotations__, but it handles
+ forward references encoded as string literals, adds Optional[t] if a
+ default value equal to None is set and recursively replaces all
+ 'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
+
+ The argument may be a module, class, method, or function. The annotations
+    are returned as a dictionary. For classes, annotations also include
+ inherited members.
+
+ TypeError is raised if the argument is not of a type that can contain
+ annotations, and an empty dictionary is returned if no annotations are
+ present.
+
+ BEWARE -- the behavior of globalns and localns is counterintuitive
+ (unless you are familiar with how eval() and exec() work). The
+ search order is locals first, then globals.
+
+ - If no dict arguments are passed, an attempt is made to use the
+ globals from obj (or the respective module's globals for classes),
+ and these are also used as the locals. If the object does not appear
+ to have globals, an empty dictionary is used.
+
+ - If one dict argument is passed, it is used for both globals and
+ locals.
+
+ - If two dict arguments are passed, they specify globals and
+ locals, respectively.
+ """
+
+ if getattr(obj, '__no_type_check__', None):
+ return {}
+ # Classes require a special treatment.
+ if isinstance(obj, type):
+ hints = {}
+ for base in reversed(obj.__mro__):
+ if globalns is None:
+ base_globals = sys.modules[base.__module__].__dict__
+ else:
+ base_globals = globalns
+ ann = base.__dict__.get('__annotations__', {})
+ for name, value in ann.items():
+ if value is None:
+ value = type(None)
+ if isinstance(value, str):
+ value = ForwardRef(value, is_argument=False, is_class=True)
+ value = _eval_type(value, base_globals, localns)
+ hints[name] = value
+ return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
+
+ if globalns is None:
+ if isinstance(obj, types.ModuleType):
+ globalns = obj.__dict__
+ else:
+ nsobj = obj
+ # Find globalns for the unwrapped object.
+ while hasattr(nsobj, '__wrapped__'):
+ nsobj = nsobj.__wrapped__
+ globalns = getattr(nsobj, '__globals__', {})
+ if localns is None:
+ localns = globalns
+ elif localns is None:
+ localns = globalns
+ hints = getattr(obj, '__annotations__', None)
+ if hints is None:
+ # Return empty annotations for something that _could_ have them.
+ if isinstance(obj, _allowed_types):
+ return {}
+ else:
+ raise TypeError('{!r} is not a module, class, method, '
+ 'or function.'.format(obj))
+ defaults = _get_defaults(obj)
+ hints = dict(hints)
+ for name, value in hints.items():
+ if value is None:
+ value = type(None)
+ if isinstance(value, str):
+ # class-level forward refs were handled above, this must be either
+ # a module-level annotation or a function argument annotation
+ value = ForwardRef(
+ value,
+ is_argument=not isinstance(obj, types.ModuleType),
+ is_class=False,
+ )
+ value = _eval_type(value, globalns, localns)
+ if name in defaults and defaults[name] is None:
+ value = Optional[value]
+ hints[name] = value
+ return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
+
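# --- Editor's illustration (not part of this patch) ---------------------------
# Unlike raw __annotations__, get_type_hints() resolves string annotations and
# wraps parameters whose default is None in Optional[...].
from typing import Optional, get_type_hints

def greet(name: "str", prefix: str = None) -> str:
    return f"{prefix or ''}{name}"

assert get_type_hints(greet) == {'name': str, 'prefix': Optional[str], 'return': str}
# -------------------------------------------------------------------------------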
+
+def _strip_annotations(t):
+ """Strips the annotations from a given type.
+ """
+ if isinstance(t, _AnnotatedAlias):
+ return _strip_annotations(t.__origin__)
+ if isinstance(t, _GenericAlias):
+ stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
+ if stripped_args == t.__args__:
+ return t
+ return t.copy_with(stripped_args)
+ if isinstance(t, GenericAlias):
+ stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
+ if stripped_args == t.__args__:
+ return t
+ return GenericAlias(t.__origin__, stripped_args)
+ return t
+
+
+def get_origin(tp):
+ """Get the unsubscripted version of a type.
+
+ This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar
+ and Annotated. Return None for unsupported types. Examples::
+
+ get_origin(Literal[42]) is Literal
+ get_origin(int) is None
+ get_origin(ClassVar[int]) is ClassVar
+ get_origin(Generic) is Generic
+ get_origin(Generic[T]) is Generic
+ get_origin(Union[T, int]) is Union
+ get_origin(List[Tuple[T, T]][int]) == list
+ """
+ if isinstance(tp, _AnnotatedAlias):
+ return Annotated
+ if isinstance(tp, (_BaseGenericAlias, GenericAlias)):
+ return tp.__origin__
+ if tp is Generic:
+ return Generic
+ return None
+
+
+def get_args(tp):
+ """Get type arguments with all substitutions performed.
+
+    For unions, basic simplifications used by the Union constructor are performed.
+ Examples::
+ get_args(Dict[str, int]) == (str, int)
+ get_args(int) == ()
+ get_args(Union[int, Union[T, int], str][int]) == (int, str)
+ get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
+ get_args(Callable[[], T][int]) == ([], int)
+ """
+ if isinstance(tp, _AnnotatedAlias):
+ return (tp.__origin__,) + tp.__metadata__
+ if isinstance(tp, (_GenericAlias, GenericAlias)):
+ res = tp.__args__
+ if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
+ res = (list(res[:-1]), res[-1])
+ return res
+ return ()
+
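# --- Editor's illustration (not part of this patch) ---------------------------
# get_origin()/get_args() decompose a subscripted type into its constructor and
# its parameters, the usual entry point for runtime introspection.
from typing import Dict, get_args, get_origin

assert get_origin(Dict[str, int]) is dict
assert get_args(Dict[str, int]) == (str, int)
assert get_origin(int) is None and get_args(int) == ()
# -------------------------------------------------------------------------------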
+
+def no_type_check(arg):
+ """Decorator to indicate that annotations are not type hints.
+
+ The argument must be a class or function; if it is a class, it
+ applies recursively to all methods and classes defined in that class
+ (but not to methods defined in its superclasses or subclasses).
+
+ This mutates the function(s) or class(es) in place.
+ """
+ if isinstance(arg, type):
+ arg_attrs = arg.__dict__.copy()
+ for attr, val in arg.__dict__.items():
+ if val in arg.__bases__ + (arg,):
+ arg_attrs.pop(attr)
+ for obj in arg_attrs.values():
+ if isinstance(obj, types.FunctionType):
+ obj.__no_type_check__ = True
+ if isinstance(obj, type):
+ no_type_check(obj)
+ try:
+ arg.__no_type_check__ = True
+ except TypeError: # built-in classes
+ pass
+ return arg
+
+
+def no_type_check_decorator(decorator):
+ """Decorator to give another decorator the @no_type_check effect.
+
+ This wraps the decorator with something that wraps the decorated
+ function in @no_type_check.
+ """
+
+ @functools.wraps(decorator)
+ def wrapped_decorator(*args, **kwds):
+ func = decorator(*args, **kwds)
+ func = no_type_check(func)
+ return func
+
+ return wrapped_decorator
+
+
+def _overload_dummy(*args, **kwds):
+ """Helper for @overload to raise when called."""
+ raise NotImplementedError(
+ "You should not call an overloaded function. "
+ "A series of @overload-decorated functions "
+ "outside a stub module should always be followed "
+ "by an implementation that is not @overload-ed.")
+
+
+def overload(func):
+ """Decorator for overloaded functions/methods.
+
+ In a stub file, place two or more stub definitions for the same
+ function in a row, each decorated with @overload. For example:
+
+ @overload
+ def utf8(value: None) -> None: ...
+ @overload
+ def utf8(value: bytes) -> bytes: ...
+ @overload
+ def utf8(value: str) -> bytes: ...
+
+ In a non-stub file (i.e. a regular .py file), do the same but
+ follow it with an implementation. The implementation should *not*
+ be decorated with @overload. For example:
+
+ @overload
+ def utf8(value: None) -> None: ...
+ @overload
+ def utf8(value: bytes) -> bytes: ...
+ @overload
+ def utf8(value: str) -> bytes: ...
+ def utf8(value):
+ # implementation goes here
+ """
+ return _overload_dummy
+
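# --- Editor's illustration (not part of this patch) ---------------------------
# At runtime every @overload stub is replaced by _overload_dummy, so only the
# final undecorated implementation is callable.
from typing import overload

@overload
def double(x: int) -> int: ...
@overload
def double(x: str) -> str: ...
def double(x):
    return x * 2

assert double(3) == 6 and double("ab") == "abab"
# -------------------------------------------------------------------------------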
+
+def final(f):
+ """A decorator to indicate final methods and final classes.
+
+ Use this decorator to indicate to type checkers that the decorated
+ method cannot be overridden, and decorated class cannot be subclassed.
+ For example:
+
+ class Base:
+ @final
+ def done(self) -> None:
+ ...
+ class Sub(Base):
+ def done(self) -> None: # Error reported by type checker
+ ...
+
+ @final
+ class Leaf:
+ ...
+ class Other(Leaf): # Error reported by type checker
+ ...
+
+ There is no runtime checking of these properties.
+ """
+ return f
+
+
+# Some unconstrained type variables. These are used by the container types.
+# (These are not for export.)
+T = TypeVar('T') # Any type.
+KT = TypeVar('KT') # Key type.
+VT = TypeVar('VT') # Value type.
+T_co = TypeVar('T_co', covariant=True) # Any type covariant containers.
+V_co = TypeVar('V_co', covariant=True) # Any type covariant containers.
+VT_co = TypeVar('VT_co', covariant=True) # Value type covariant containers.
+T_contra = TypeVar('T_contra', contravariant=True) # Ditto contravariant.
+# Internal type variable used for Type[].
+CT_co = TypeVar('CT_co', covariant=True, bound=type)
+
+# A useful type variable with constraints. This represents string types.
+# (This one *is* for export!)
+AnyStr = TypeVar('AnyStr', bytes, str)
+
+# Various ABCs mimicking those in collections.abc.
+_alias = _SpecialGenericAlias
+
+Hashable = _alias(collections.abc.Hashable, 0) # Not generic.
+Awaitable = _alias(collections.abc.Awaitable, 1)
+Coroutine = _alias(collections.abc.Coroutine, 3)
+AsyncIterable = _alias(collections.abc.AsyncIterable, 1)
+AsyncIterator = _alias(collections.abc.AsyncIterator, 1)
+Iterable = _alias(collections.abc.Iterable, 1)
+Iterator = _alias(collections.abc.Iterator, 1)
+Reversible = _alias(collections.abc.Reversible, 1)
+Sized = _alias(collections.abc.Sized, 0) # Not generic.
+Container = _alias(collections.abc.Container, 1)
+Collection = _alias(collections.abc.Collection, 1)
+Callable = _CallableType(collections.abc.Callable, 2)
+Callable.__doc__ = \
+ """Callable type; Callable[[int], str] is a function of (int) -> str.
+
+ The subscription syntax must always be used with exactly two
+ values: the argument list and the return type. The argument list
+ must be a list of types or ellipsis; the return type must be a single type.
+
+    There is no syntax to indicate optional or keyword arguments;
+ such function types are rarely used as callback types.
+ """
+AbstractSet = _alias(collections.abc.Set, 1, name='AbstractSet')
+MutableSet = _alias(collections.abc.MutableSet, 1)
+# NOTE: Mapping is only covariant in the value type.
+Mapping = _alias(collections.abc.Mapping, 2)
+MutableMapping = _alias(collections.abc.MutableMapping, 2)
+Sequence = _alias(collections.abc.Sequence, 1)
+MutableSequence = _alias(collections.abc.MutableSequence, 1)
+ByteString = _alias(collections.abc.ByteString, 0) # Not generic
+# Tuple accepts variable number of parameters.
+Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
+Tuple.__doc__ = \
+ """Tuple type; Tuple[X, Y] is the cross-product type of X and Y.
+
+ Example: Tuple[T1, T2] is a tuple of two elements corresponding
+ to type variables T1 and T2. Tuple[int, float, str] is a tuple
+ of an int, a float and a string.
+
+ To specify a variable-length tuple of homogeneous type, use Tuple[T, ...].
+ """
+List = _alias(list, 1, inst=False, name='List')
+Deque = _alias(collections.deque, 1, name='Deque')
+Set = _alias(set, 1, inst=False, name='Set')
+FrozenSet = _alias(frozenset, 1, inst=False, name='FrozenSet')
+MappingView = _alias(collections.abc.MappingView, 1)
+KeysView = _alias(collections.abc.KeysView, 1)
+ItemsView = _alias(collections.abc.ItemsView, 2)
+ValuesView = _alias(collections.abc.ValuesView, 1)
+ContextManager = _alias(contextlib.AbstractContextManager, 1, name='ContextManager')
+AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, 1, name='AsyncContextManager')
+Dict = _alias(dict, 2, inst=False, name='Dict')
+DefaultDict = _alias(collections.defaultdict, 2, name='DefaultDict')
+OrderedDict = _alias(collections.OrderedDict, 2)
+Counter = _alias(collections.Counter, 1)
+ChainMap = _alias(collections.ChainMap, 2)
+Generator = _alias(collections.abc.Generator, 3)
+AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2)
+Type = _alias(type, 1, inst=False, name='Type')
+Type.__doc__ = \
+ """A special construct usable to annotate class objects.
+
+ For example, suppose we have the following classes::
+
+ class User: ... # Abstract base for User classes
+ class BasicUser(User): ...
+ class ProUser(User): ...
+ class TeamUser(User): ...
+
+ And a function that takes a class argument that's a subclass of
+ User and returns an instance of the corresponding class::
+
+ U = TypeVar('U', bound=User)
+ def new_user(user_class: Type[U]) -> U:
+ user = user_class()
+ # (Here we could write the user object to a database)
+ return user
+
+ joe = new_user(BasicUser)
+
+ At this point the type checker knows that joe has type BasicUser.
+ """
+
+
+@runtime_checkable
+class SupportsInt(Protocol):
+ """An ABC with one abstract method __int__."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __int__(self) -> int:
+ pass
+
+
+@runtime_checkable
+class SupportsFloat(Protocol):
+ """An ABC with one abstract method __float__."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __float__(self) -> float:
+ pass
+
+
+@runtime_checkable
+class SupportsComplex(Protocol):
+ """An ABC with one abstract method __complex__."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __complex__(self) -> complex:
+ pass
+
+
+@runtime_checkable
+class SupportsBytes(Protocol):
+ """An ABC with one abstract method __bytes__."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __bytes__(self) -> bytes:
+ pass
+
+
+@runtime_checkable
+class SupportsIndex(Protocol):
+ """An ABC with one abstract method __index__."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __index__(self) -> int:
+ pass
+
+
+@runtime_checkable
+class SupportsAbs(Protocol[T_co]):
+ """An ABC with one abstract method __abs__ that is covariant in its return type."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __abs__(self) -> T_co:
+ pass
+
+
+@runtime_checkable
+class SupportsRound(Protocol[T_co]):
+ """An ABC with one abstract method __round__ that is covariant in its return type."""
+ __slots__ = ()
+
+ @abstractmethod
+ def __round__(self, ndigits: int = 0) -> T_co:
+ pass
+
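# --- Editor's illustration (not part of this patch) ---------------------------
# The Supports* protocols above are runtime-checkable, so they can express
# "anything convertible to X" checks at runtime.
from typing import SupportsAbs, SupportsInt

assert isinstance(3.75, SupportsInt)          # float defines __int__
assert not isinstance("3.75", SupportsInt)    # str does not
assert isinstance(-2, SupportsAbs)            # int defines __abs__
# -------------------------------------------------------------------------------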
+
+def _make_nmtuple(name, types, module, defaults=()):
+ fields = [n for n, t in types]
+ types = {n: _type_check(t, f"field {n} annotation must be a type")
+ for n, t in types}
+ nm_tpl = collections.namedtuple(name, fields,
+ defaults=defaults, module=module)
+ nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = types
+ return nm_tpl
+
+
+# attributes prohibited to set in NamedTuple class syntax
+_prohibited = frozenset({'__new__', '__init__', '__slots__', '__getnewargs__',
+ '_fields', '_field_defaults',
+ '_make', '_replace', '_asdict', '_source'})
+
+_special = frozenset({'__module__', '__name__', '__annotations__'})
+
+
+class NamedTupleMeta(type):
+
+ def __new__(cls, typename, bases, ns):
+ assert bases[0] is _NamedTuple
+ types = ns.get('__annotations__', {})
+ default_names = []
+ for field_name in types:
+ if field_name in ns:
+ default_names.append(field_name)
+ elif default_names:
+ raise TypeError(f"Non-default namedtuple field {field_name} "
+ f"cannot follow default field"
+ f"{'s' if len(default_names) > 1 else ''} "
+ f"{', '.join(default_names)}")
+ nm_tpl = _make_nmtuple(typename, types.items(),
+ defaults=[ns[n] for n in default_names],
+ module=ns['__module__'])
+ # update from user namespace without overriding special namedtuple attributes
+ for key in ns:
+ if key in _prohibited:
+ raise AttributeError("Cannot overwrite NamedTuple attribute " + key)
+ elif key not in _special and key not in nm_tpl._fields:
+ setattr(nm_tpl, key, ns[key])
+ return nm_tpl
+
+
+def NamedTuple(typename, fields=None, /, **kwargs):
+ """Typed version of namedtuple.
+
+ Usage in Python versions >= 3.6::
+
+ class Employee(NamedTuple):
+ name: str
+ id: int
+
+ This is equivalent to::
+
+ Employee = collections.namedtuple('Employee', ['name', 'id'])
+
+ The resulting class has an extra __annotations__ attribute, giving a
+ dict that maps field names to types. (The field names are also in
+ the _fields attribute, which is part of the namedtuple API.)
+ Alternative equivalent keyword syntax is also accepted::
+
+ Employee = NamedTuple('Employee', name=str, id=int)
+
+ In Python versions <= 3.5 use::
+
+ Employee = NamedTuple('Employee', [('name', str), ('id', int)])
+ """
+ if fields is None:
+ fields = kwargs.items()
+ elif kwargs:
+ raise TypeError("Either list of fields or keywords"
+ " can be provided to NamedTuple, not both")
+ try:
+ module = sys._getframe(1).f_globals.get('__name__', '__main__')
+ except (AttributeError, ValueError):
+ module = None
+ return _make_nmtuple(typename, fields, module=module)
+
+
+_NamedTuple = type.__new__(NamedTupleMeta, 'NamedTuple', (), {})
+
+
+def _namedtuple_mro_entries(bases):
+ if len(bases) > 1:
+ raise TypeError("Multiple inheritance with NamedTuple is not supported")
+ assert bases[0] is NamedTuple
+ return (_NamedTuple,)
+
+
+NamedTuple.__mro_entries__ = _namedtuple_mro_entries
+
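# --- Editor's illustration (not part of this patch) ---------------------------
# Both spellings of NamedTuple produce an ordinary namedtuple class whose field
# types are recorded in __annotations__.
from typing import NamedTuple

class Point(NamedTuple):
    x: int
    y: int = 0

assert Point(3) == (3, 0)
assert Point.__annotations__ == {'x': int, 'y': int}
# -------------------------------------------------------------------------------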
+
+class _TypedDictMeta(type):
+ def __new__(cls, name, bases, ns, total=True):
+ """Create new typed dict class object.
+
+ This method is called when TypedDict is subclassed,
+ or when TypedDict is instantiated. This way
+ TypedDict supports all three syntax forms described in its docstring.
+ Subclasses and instances of TypedDict return actual dictionaries.
+ """
+ for base in bases:
+ if type(base) is not _TypedDictMeta:
+ raise TypeError('cannot inherit from both a TypedDict type '
+ 'and a non-TypedDict base class')
+ tp_dict = type.__new__(_TypedDictMeta, name, (dict,), ns)
+
+ annotations = {}
+ own_annotations = ns.get('__annotations__', {})
+ own_annotation_keys = set(own_annotations.keys())
+ msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
+ own_annotations = {
+ n: _type_check(tp, msg, module=tp_dict.__module__)
+ for n, tp in own_annotations.items()
+ }
+ required_keys = set()
+ optional_keys = set()
+
+ for base in bases:
+ annotations.update(base.__dict__.get('__annotations__', {}))
+ required_keys.update(base.__dict__.get('__required_keys__', ()))
+ optional_keys.update(base.__dict__.get('__optional_keys__', ()))
+
+ annotations.update(own_annotations)
+ if total:
+ required_keys.update(own_annotation_keys)
+ else:
+ optional_keys.update(own_annotation_keys)
+
+ tp_dict.__annotations__ = annotations
+ tp_dict.__required_keys__ = frozenset(required_keys)
+ tp_dict.__optional_keys__ = frozenset(optional_keys)
+ if not hasattr(tp_dict, '__total__'):
+ tp_dict.__total__ = total
+ return tp_dict
+
+ __call__ = dict # static method
+
+ def __subclasscheck__(cls, other):
+ # Typed dicts are only for static structural subtyping.
+ raise TypeError('TypedDict does not support instance and class checks')
+
+ __instancecheck__ = __subclasscheck__
+
+
+def TypedDict(typename, fields=None, /, *, total=True, **kwargs):
+ """A simple typed namespace. At runtime it is equivalent to a plain dict.
+
+ TypedDict creates a dictionary type that expects all of its
+ instances to have a certain set of keys, where each key is
+ associated with a value of a consistent type. This expectation
+ is not checked at runtime but is only enforced by type checkers.
+ Usage::
+
+ class Point2D(TypedDict):
+ x: int
+ y: int
+ label: str
+
+ a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
+ b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
+
+ assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
+
+ The type info can be accessed via the Point2D.__annotations__ dict, and
+ the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
+ TypedDict supports two additional equivalent forms::
+
+ Point2D = TypedDict('Point2D', x=int, y=int, label=str)
+ Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
+
+ By default, all keys must be present in a TypedDict. It is possible
+ to override this by specifying totality.
+ Usage::
+
+ class point2D(TypedDict, total=False):
+ x: int
+ y: int
+
+    This means that a point2D TypedDict can have any of the keys omitted. A type
+ checker is only expected to support a literal False or True as the value of
+ the total argument. True is the default, and makes all items defined in the
+ class body be required.
+
+    The class syntax is only supported in Python 3.6+, while the two other
+    syntax forms also work for Python 2.7 and 3.2+.
+ """
+ if fields is None:
+ fields = kwargs
+ elif kwargs:
+ raise TypeError("TypedDict takes either a dict or keyword arguments,"
+ " but not both")
+
+ ns = {'__annotations__': dict(fields)}
+ try:
+ # Setting correct module is necessary to make typed dict classes pickleable.
+ ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__')
+ except (AttributeError, ValueError):
+ pass
+
+ return _TypedDictMeta(typename, (), ns, total=total)
+
+
+_TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {})
+TypedDict.__mro_entries__ = lambda bases: (_TypedDict,)
+
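# --- Editor's illustration (not part of this patch) ---------------------------
# A TypedDict subclass is a plain dict at runtime; total=False only affects the
# __required_keys__/__optional_keys__ bookkeeping consumed by type checkers.
from typing import TypedDict

class Movie(TypedDict, total=False):
    title: str
    year: int

m = Movie(title='Alien')
assert isinstance(m, dict) and m == {'title': 'Alien'}
assert Movie.__optional_keys__ == frozenset({'title', 'year'})
# -------------------------------------------------------------------------------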
+
+def NewType(name, tp):
+ """NewType creates simple unique types with almost zero
+ runtime overhead. NewType(name, tp) is considered a subtype of tp
+ by static type checkers. At runtime, NewType(name, tp) returns
+ a dummy function that simply returns its argument. Usage::
+
+ UserId = NewType('UserId', int)
+
+ def name_by_id(user_id: UserId) -> str:
+ ...
+
+ UserId('user') # Fails type check
+
+ name_by_id(42) # Fails type check
+ name_by_id(UserId(42)) # OK
+
+ num = UserId(5) + 1 # type: int
+ """
+
+ def new_type(x):
+ return x
+
+ new_type.__name__ = name
+ new_type.__supertype__ = tp
+ return new_type
+
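# --- Editor's illustration (not part of this patch) ---------------------------
# At runtime a NewType is just an identity function with some metadata; the
# distinct-type guarantee exists only for the static checker.
from typing import NewType

UserId = NewType('UserId', int)
uid = UserId(42)
assert uid == 42 and type(uid) is int
assert UserId.__supertype__ is int
# -------------------------------------------------------------------------------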
+
+# Python-version-specific alias (Python 2: unicode; Python 3: str)
+Text = str
+
+# Constant that's True when type checking, but False here.
+TYPE_CHECKING = False
+
+
+class IO(Generic[AnyStr]):
+ """Generic base class for TextIO and BinaryIO.
+
+ This is an abstract, generic version of the return of open().
+
+ NOTE: This does not distinguish between the different possible
+ classes (text vs. binary, read vs. write vs. read/write,
+ append-only, unbuffered). The TextIO and BinaryIO subclasses
+ below capture the distinctions between text vs. binary, which is
+ pervasive in the interface; however we currently do not offer a
+ way to track the other distinctions in the type system.
+ """
+
+ __slots__ = ()
+
+ @property
+ @abstractmethod
+ def mode(self) -> str:
+ pass
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ pass
+
+ @abstractmethod
+ def close(self) -> None:
+ pass
+
+ @property
+ @abstractmethod
+ def closed(self) -> bool:
+ pass
+
+ @abstractmethod
+ def fileno(self) -> int:
+ pass
+
+ @abstractmethod
+ def flush(self) -> None:
+ pass
+
+ @abstractmethod
+ def isatty(self) -> bool:
+ pass
+
+ @abstractmethod
+ def read(self, n: int = -1) -> AnyStr:
+ pass
+
+ @abstractmethod
+ def readable(self) -> bool:
+ pass
+
+ @abstractmethod
+ def readline(self, limit: int = -1) -> AnyStr:
+ pass
+
+ @abstractmethod
+ def readlines(self, hint: int = -1) -> List[AnyStr]:
+ pass
+
+ @abstractmethod
+ def seek(self, offset: int, whence: int = 0) -> int:
+ pass
+
+ @abstractmethod
+ def seekable(self) -> bool:
+ pass
+
+ @abstractmethod
+ def tell(self) -> int:
+ pass
+
+ @abstractmethod
+ def truncate(self, size: int = None) -> int:
+ pass
+
+ @abstractmethod
+ def writable(self) -> bool:
+ pass
+
+ @abstractmethod
+ def write(self, s: AnyStr) -> int:
+ pass
+
+ @abstractmethod
+ def writelines(self, lines: List[AnyStr]) -> None:
+ pass
+
+ @abstractmethod
+ def __enter__(self) -> 'IO[AnyStr]':
+ pass
+
+ @abstractmethod
+ def __exit__(self, type, value, traceback) -> None:
+ pass
+
+
+class BinaryIO(IO[bytes]):
+ """Typed version of the return of open() in binary mode."""
+
+ __slots__ = ()
+
+ @abstractmethod
+ def write(self, s: Union[bytes, bytearray]) -> int:
+ pass
+
+ @abstractmethod
+ def __enter__(self) -> 'BinaryIO':
+ pass
+
+
+class TextIO(IO[str]):
+ """Typed version of the return of open() in text mode."""
+
+ __slots__ = ()
+
+ @property
+ @abstractmethod
+ def buffer(self) -> BinaryIO:
+ pass
+
+ @property
+ @abstractmethod
+ def encoding(self) -> str:
+ pass
+
+ @property
+ @abstractmethod
+ def errors(self) -> Optional[str]:
+ pass
+
+ @property
+ @abstractmethod
+ def line_buffering(self) -> bool:
+ pass
+
+ @property
+ @abstractmethod
+ def newlines(self) -> Any:
+ pass
+
+ @abstractmethod
+ def __enter__(self) -> 'TextIO':
+ pass
+
+
+class io:
+ """Wrapper namespace for IO generic classes."""
+
+ __all__ = ['IO', 'TextIO', 'BinaryIO']
+ IO = IO
+ TextIO = TextIO
+ BinaryIO = BinaryIO
+
+
+io.__name__ = __name__ + '.io'
+sys.modules[io.__name__] = io
+
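# --- Editor's illustration (not part of this patch) ---------------------------
# IO[AnyStr], TextIO and BinaryIO are annotation-only stand-ins for file-like
# objects; real file objects are not runtime instances of these classes.
from typing import TextIO

def count_lines(stream: TextIO) -> int:
    return sum(1 for _ in stream)
# -------------------------------------------------------------------------------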
+Pattern = _alias(stdlib_re.Pattern, 1)
+Match = _alias(stdlib_re.Match, 1)
+
+
+class re:
+ """Wrapper namespace for re type aliases."""
+
+ __all__ = ['Pattern', 'Match']
+ Pattern = Pattern
+ Match = Match
+
+
+re.__name__ = __name__ + '.re'
+sys.modules[re.__name__] = re
diff --git a/brainpy/channels.py b/brainpy/channels.py
index 16769e2f1..b471c1194 100644
--- a/brainpy/channels.py
+++ b/brainpy/channels.py
@@ -1,58 +1,3 @@
# -*- coding: utf-8 -*-
-from brainpy._src.dyn.channels.base import (
- Ion as Ion,
- IonChannel as IonChannel,
- Calcium as Calcium,
- IhChannel as IhChannel,
- CalciumChannel as CalciumChannel,
- SodiumChannel as SodiumChannel,
- PotassiumChannel as PotassiumChannel,
- LeakyChannel as LeakyChannel,
-)
-
-from brainpy._src.dyn.channels.Ca import (
- CalciumFixed as CalciumFixed,
- CalciumDyna as CalciumDyna,
- CalciumDetailed as CalciumDetailed,
- CalciumFirstOrder as CalciumFirstOrder,
- ICaN_IS2008 as ICaN_IS2008,
- ICaT_HM1992 as ICaT_HM1992,
- ICaT_HP1992 as ICaT_HP1992,
- ICaHT_HM1992 as ICaHT_HM1992,
- ICaL_IS2008 as ICaL_IS2008,
-)
-
-from brainpy._src.dyn.channels.IH import (
- Ih_HM1992 as Ih_HM1992,
- Ih_De1996 as Ih_De1996,
-)
-
-from brainpy._src.dyn.channels.K import (
- IKDR_Ba2002 as IKDR_Ba2002,
- IK_TM1991 as IK_TM1991,
- IK_HH1952 as IK_HH1952,
- IKA1_HM1992 as IKA1_HM1992,
- IKA2_HM1992 as IKA2_HM1992,
- IKK2A_HM1992 as IKK2A_HM1992,
- IKK2B_HM1992 as IKK2B_HM1992,
- IKNI_Ya1989 as IKNI_Ya1989,
-)
-
-from brainpy._src.dyn.channels.KCa import (
- IAHP_De1994 as IAHP_De1994,
-)
-
-from brainpy._src.dyn.channels.leaky import (
- IL as IL,
- IKL as IKL,
-)
-
-from brainpy._src.dyn.channels.Na import (
- INa_Ba2002 as INa_Ba2002,
- INa_TM1991 as INa_TM1991,
- INa_HH1952 as INa_HH1952,
-)
-
-
-
+from .dyn.channels import *
diff --git a/brainpy/dyn/__init__.py b/brainpy/dyn/__init__.py
index 049a0c364..6471e011d 100644
--- a/brainpy/dyn/__init__.py
+++ b/brainpy/dyn/__init__.py
@@ -1,6 +1,8 @@
+from .ions import *
from .channels import *
from .neurons import *
from .synapses import *
from .projections import *
from .others import *
+from .outs import *
diff --git a/brainpy/dyn/channels.py b/brainpy/dyn/channels.py
index f4f0d0283..df5bdd927 100644
--- a/brainpy/dyn/channels.py
+++ b/brainpy/dyn/channels.py
@@ -1,20 +1,9 @@
from brainpy._src.dyn.channels.base import (
- Ion,
IonChannel,
- Calcium,
- IhChannel,
- CalciumChannel,
- SodiumChannel,
- PotassiumChannel,
- LeakyChannel,
)
+from brainpy._src.dyn.channels.base import CalciumChannel
from brainpy._src.dyn.channels.Ca import (
- CalciumFixed,
- CalciumChannel,
- CalciumDetailed,
- CalciumFirstOrder,
- CalciumDyna,
ICaN_IS2008,
ICaT_HM1992,
ICaT_HP1992,
@@ -22,6 +11,8 @@
ICaL_IS2008,
)
+
+from brainpy._src.dyn.channels.base import PotassiumChannel
from brainpy._src.dyn.channels.K import (
IKDR_Ba2002,
IK_TM1991,
@@ -33,22 +24,30 @@
IKNI_Ya1989,
)
+
+from brainpy._src.dyn.channels.base import IhChannel
from brainpy._src.dyn.channels.IH import (
Ih_HM1992,
Ih_De1996,
)
+
from brainpy._src.dyn.channels.KCa import (
IAHP_De1994
)
+
+from brainpy._src.dyn.channels.base import SodiumChannel
from brainpy._src.dyn.channels.Na import (
INa_Ba2002,
INa_TM1991,
INa_HH1952,
)
+
+from brainpy._src.dyn.channels.base import LeakyChannel
from brainpy._src.dyn.channels.leaky import (
IL,
IKL,
)
+
diff --git a/brainpy/dyn/ions.py b/brainpy/dyn/ions.py
new file mode 100644
index 000000000..8f040c971
--- /dev/null
+++ b/brainpy/dyn/ions.py
@@ -0,0 +1,12 @@
+
+from brainpy._src.dyn.ions.base import (
+ Ion as Ion,
+ Calcium as Calcium,
+)
+
+from brainpy._src.dyn.ions.ca import (
+ CalciumFixed as CalciumFixed,
+ CalciumDetailed as CalciumDetailed,
+ CalciumFirstOrder as CalciumFirstOrder,
+)
+
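# --- Editor's illustration (not part of this patch) ---------------------------
# With the reorganized layout above, ion models move to the new
# brainpy.dyn.ions module while channel models remain under
# brainpy.dyn.channels (paths taken from the hunks in this patch; constructor
# arguments are not shown here and are therefore omitted).
from brainpy.dyn.ions import Calcium, CalciumFixed
from brainpy.dyn.channels import ICaT_HM1992
# -------------------------------------------------------------------------------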
diff --git a/brainpy/dyn/neurons.py b/brainpy/dyn/neurons.py
index 61ab26852..ae4d06ee8 100644
--- a/brainpy/dyn/neurons.py
+++ b/brainpy/dyn/neurons.py
@@ -1,10 +1,4 @@
-from brainpy._src.dyn.base import (
- NeuDyn,
- GradNeuDyn,
- HHTypeNeu,
- HHTypeNeuLTC
-)
from brainpy._src.dyn.neurons.lif import (
Lif,
@@ -38,19 +32,16 @@
)
from brainpy._src.dyn.neurons.hh import (
+ CondNeuGroupLTC,
+ CondNeuGroup,
HH,
HHLTC,
MorrisLecar,
MorrisLecarLTC,
- WangBuzsakiModel,
- WangBuzsakiModelLTC,
+ WangBuzsakiHH,
+ WangBuzsakiHHLTC,
)
-from brainpy._src.dyn.neurons.input import (
- InputGroup,
- OutputGroup,
- SpikeTimeGroup,
- PoissonGroup,
-)
+
diff --git a/brainpy/dyn/others.py b/brainpy/dyn/others.py
index 1183608f5..8ecd9bf8b 100644
--- a/brainpy/dyn/others.py
+++ b/brainpy/dyn/others.py
@@ -1,4 +1,17 @@
from brainpy._src.dyn.others.common import (
- Leaky,
- Integrator,
-)
\ No newline at end of file
+ Leaky as Leaky,
+ Integrator as Integrator,
+)
+
+from brainpy._src.dyn.others.input import (
+ InputGroup as InputGroup,
+ OutputGroup as OutputGroup,
+ SpikeTimeGroup as SpikeTimeGroup,
+ PoissonGroup as PoissonGroup,
+)
+
+
+from brainpy._src.dyn.others.noise import (
+ OUProcess as OUProcess,
+)
+
diff --git a/brainpy/dyn/outs.py b/brainpy/dyn/outs.py
new file mode 100644
index 000000000..e2e602d0c
--- /dev/null
+++ b/brainpy/dyn/outs.py
@@ -0,0 +1,8 @@
+from brainpy._src.dyn.outs.base import (
+ SynOut,
+)
+from brainpy._src.dyn.outs.outputs import (
+ COBA,
+ CUBA,
+ MgBlock,
+)
diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py
index a5448074b..15dde3d57 100644
--- a/brainpy/dyn/projections.py
+++ b/brainpy/dyn/projections.py
@@ -1,7 +1,11 @@
-from brainpy._src.dyn.projections import (
- ProjAlignPost,
- ProjAlignPre,
+from brainpy._src.dyn.projections.aligns import (
+ ProjAlignPost as ProjAlignPost,
+ ProjAlignPre as ProjAlignPre,
+)
+
+from brainpy._src.dyn.projections.others import (
+ PoissonInput as PoissonInput,
)
diff --git a/brainpy/dyn/rates.py b/brainpy/dyn/rates.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/brainpy/dyn/synapses.py b/brainpy/dyn/synapses.py
index 3f92d0102..e59a33826 100644
--- a/brainpy/dyn/synapses.py
+++ b/brainpy/dyn/synapses.py
@@ -1,24 +1,7 @@
-from brainpy._src.dyn.base import (
- SynDyn,
- SynOut,
-)
-
-from brainpy._src.dyn.synapses.dynamics import (
+from brainpy._src.dyn.synapses.abstract_models import (
+ Delta,
Expon,
DualExpon,
- Alpha,
- NMDA,
- STD,
- STP,
- AMPA,
- GABAa,
- BioNMDA,
-)
-
-from brainpy._src.dyn.synapses.outputs import (
- COBA,
- CUBA,
- MgBlock,
)
diff --git a/brainpy/errors.py b/brainpy/errors.py
index b35a4117f..af3d51f0c 100644
--- a/brainpy/errors.py
+++ b/brainpy/errors.py
@@ -232,3 +232,9 @@ def __init__(self, name):
''')
+
+
+class SharedArgError(BrainPyError):
+ pass
+
+
diff --git a/brainpy/experimental.py b/brainpy/experimental.py
index 68d8ff5bd..c909fa633 100644
--- a/brainpy/experimental.py
+++ b/brainpy/experimental.py
@@ -1,18 +1,18 @@
-from brainpy._src.synapses_v2.syn_plasticity import (
+from brainpy._src.dynold.experimental.syn_plasticity import (
STD as STD,
STP as STP,
)
-from brainpy._src.synapses_v2.syn_outs import (
+from brainpy._src.dynold.experimental.syn_outs import (
CUBA as CUBA,
COBA as COBA,
)
-from brainpy._src.synapses_v2.abstract_synapses import (
+from brainpy._src.dynold.experimental.abstract_synapses import (
Exponential,
DualExponential,
Alpha,
)
-from brainpy._src.synapses_v2.others import (
+from brainpy._src.dynold.experimental.others import (
PoissonInput,
)
diff --git a/brainpy/mixin.py b/brainpy/mixin.py
index 09521fd0a..61bd0dca4 100644
--- a/brainpy/mixin.py
+++ b/brainpy/mixin.py
@@ -1,7 +1,11 @@
from brainpy._src.mixin import (
- MixIn,
- AlignPost,
- ProjAutoDelay,
- ParamDesc,
+ MixIn as MixIn,
+ AlignPost as AlignPost,
+ AutoDelaySupp as AutoDelaySupp,
+ ParamDesc as ParamDesc,
+ NoSH as NoSH,
+ Container as Container,
+ TreeNode as TreeNode,
+ JointType as JointType,
)
diff --git a/brainpy/neurons.py b/brainpy/neurons.py
index 0fa154538..e045035a1 100644
--- a/brainpy/neurons.py
+++ b/brainpy/neurons.py
@@ -1,32 +1,19 @@
# -*- coding: utf-8 -*-
-from brainpy._src.neurons.biological_models import (
+from brainpy._src.dynold.neurons.biological_models import (
HH as HH,
MorrisLecar as MorrisLecar,
PinskyRinzelModel as PinskyRinzelModel,
WangBuzsakiModel as WangBuzsakiModel,
)
-from brainpy._src.neurons.fractional_models import (
+from brainpy._src.dynold.neurons.fractional_models import (
FractionalNeuron as FractionalNeuron,
FractionalFHR as FractionalFHR,
FractionalIzhikevich as FractionalIzhikevich,
)
-from brainpy._src.neurons.input_groups import (
- InputGroup as InputGroup,
- OutputGroup as OutputGroup,
- SpikeTimeGroup as SpikeTimeGroup,
- PoissonGroup as PoissonGroup,
-)
-
-from brainpy._src.neurons.noise_groups import (
- OUProcess as OUProcess,
-)
-
-from brainpy._src.neurons.reduced_models import (
- Leaky as Leaky,
- Integrator as Integrator,
+from brainpy._src.dynold.neurons.reduced_models import (
LeakyIntegrator as LeakyIntegrator,
LIF as LIF,
ExpIF as ExpIF,
diff --git a/brainpy/rates.py b/brainpy/rates.py
index 7dedee342..faaaf799c 100644
--- a/brainpy/rates.py
+++ b/brainpy/rates.py
@@ -1,14 +1,3 @@
# -*- coding: utf-8 -*-
-from brainpy._src.rates.populations import (
- RateModel as RateModel,
- FHN as FHN,
- FeedbackFHN as FeedbackFHN,
- QIF as QIF,
- StuartLandauOscillator as StuartLandauOscillator,
- WilsonCowanModel as WilsonCowanModel,
- ThresholdLinearModel as ThresholdLinearModel,
-)
-
-
diff --git a/brainpy/synapses.py b/brainpy/synapses.py
new file mode 100644
index 000000000..1d1b6364f
--- /dev/null
+++ b/brainpy/synapses.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+from brainpy._src.dynold.synapses.base import (
+ SynConn as SynConn,
+ _SynSTP as SynSTP,
+ _SynOut as SynOut,
+ TwoEndConn as TwoEndConn,
+)
+from brainpy._src.dynold.synapses.biological_models import (
+ AMPA as AMPA,
+ GABAa as GABAa,
+ BioNMDA as BioNMDA,
+)
+from brainpy._src.dynold.synapses.abstract_models import (
+ Delta as Delta,
+ Exponential as Exponential,
+ DualExponential as DualExponential,
+ Alpha as Alpha,
+ NMDA as NMDA,
+)
+from brainpy._src.dynold.synapses.compat import (
+ DeltaSynapse as DeltaSynapse,
+ ExpCUBA as ExpCUBA,
+ ExpCOBA as ExpCOBA,
+ DualExpCUBA as DualExpCUBA,
+ DualExpCOBA as DualExpCOBA,
+ AlphaCUBA as AlphaCUBA,
+ AlphaCOBA as AlphaCOBA,
+)
+from brainpy._src.dynold.synapses.learning_rules import (
+ STP as STP,
+)
+
diff --git a/brainpy/synapses/__init__.py b/brainpy/synapses/__init__.py
deleted file mode 100644
index fba5a26c4..000000000
--- a/brainpy/synapses/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-
-from .dynamics import *
-from .synouts import *
-from .synplast import *
-
diff --git a/brainpy/synapses/dynamics.py b/brainpy/synapses/dynamics.py
deleted file mode 100644
index 59a8d41b5..000000000
--- a/brainpy/synapses/dynamics.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from brainpy._src.synapses.abstract_models import (
- Delta as Delta,
- Exponential as Exponential,
- DualExponential as DualExponential,
- Alpha as Alpha,
- NMDA as NMDA,
- PoissonInput as PoissonInput,
-)
-from brainpy._src.synapses.biological_models import (
- AMPA as AMPA,
- GABAa as GABAa,
- BioNMDA as BioNMDA,
-)
-from brainpy._src.synapses.delay_couplings import (
- DelayCoupling as DelayCoupling,
- DiffusiveCoupling as DiffusiveCoupling,
- AdditiveCoupling as AdditiveCoupling,
-)
-from brainpy._src.synapses.gap_junction import (
- GapJunction as GapJunction,
-)
-
-
diff --git a/brainpy/synapses/synouts.py b/brainpy/synapses/synouts.py
deleted file mode 100644
index c8be34142..000000000
--- a/brainpy/synapses/synouts.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from brainpy._src.synouts.conductances import (
- COBA as COBA,
- CUBA as CUBA,
-)
-from brainpy._src.synouts.ions import (
- MgBlock as MgBlock,
-)
-
diff --git a/brainpy/synapses/synplast.py b/brainpy/synapses/synplast.py
deleted file mode 100644
index fed0ab8b3..000000000
--- a/brainpy/synapses/synplast.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# -*- coding: utf-8 -*-
-
-from brainpy._src.synplast.short_term_plasticity import (
- STD as STD,
- STP as STP,
-)
diff --git a/brainpy/synouts.py b/brainpy/synouts.py
new file mode 100644
index 000000000..8e2b214c9
--- /dev/null
+++ b/brainpy/synouts.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+
+from brainpy._src.dynold.synouts.conductances import (
+ COBA as COBA,
+ CUBA as CUBA,
+)
+from brainpy._src.dynold.synouts.ions import (
+ MgBlock as MgBlock,
+)
+
diff --git a/brainpy/synplast.py b/brainpy/synplast.py
new file mode 100644
index 000000000..f551bc2cd
--- /dev/null
+++ b/brainpy/synplast.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from brainpy._src.dynold.synplast.short_term_plasticity import (
+ STD as STD,
+ STP as STP,
+)
From 04eef834a7f642e8ed36c81b265e58d27b912c2a Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 8 Jul 2023 22:32:34 +0800
Subject: [PATCH 015/326] update examples and tests
---
.github/workflows/CI-models.yml | 3 ---
examples/dynamics_analysis/3d_reduced_trn_model.py | 2 +-
examples/dynamics_simulation/COBA-v2.py | 8 ++++----
examples/dynamics_simulation/COBA.py | 4 ++--
examples/dynamics_training/echo_state_network.py | 4 ++--
examples/dynamics_training/integrate_flax_into_brainpy.py | 2 +-
examples/dynamics_training/integrator_rnn.py | 2 +-
examples/training_snn_models/spikebased_bp_for_cifar10.py | 6 +++---
tests/simulation/test_net_COBA.py | 4 ++--
tests/simulation/test_neu_HH.py | 2 +-
tests/training/test_ESN.py | 4 ++--
11 files changed, 19 insertions(+), 22 deletions(-)
diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml
index 2b1af1111..2fef6aad2 100644
--- a/.github/workflows/CI-models.yml
+++ b/.github/workflows/CI-models.yml
@@ -16,7 +16,6 @@ on:
jobs:
test_linux:
runs-on: ubuntu-latest
- if: github.event.pull_request.merged == true
strategy:
fail-fast: false
matrix:
@@ -64,7 +63,6 @@ jobs:
test_macos:
runs-on: macos-latest
- if: github.event.pull_request.merged == true
strategy:
fail-fast: false
matrix:
@@ -113,7 +111,6 @@ jobs:
test_windows:
runs-on: windows-latest
- if: github.event.pull_request.merged == true
strategy:
fail-fast: false
matrix:
diff --git a/examples/dynamics_analysis/3d_reduced_trn_model.py b/examples/dynamics_analysis/3d_reduced_trn_model.py
index 247e91281..fde3da625 100644
--- a/examples/dynamics_analysis/3d_reduced_trn_model.py
+++ b/examples/dynamics_analysis/3d_reduced_trn_model.py
@@ -7,7 +7,7 @@
bp.math.set_platform('cpu')
-class ReducedTRNModel(bp.NeuGroup):
+class ReducedTRNModel(bp.NeuDyn):
def __init__(self, size, name=None, T=36., method='rk4'):
super(ReducedTRNModel, self).__init__(size=size, name=name)
diff --git a/examples/dynamics_simulation/COBA-v2.py b/examples/dynamics_simulation/COBA-v2.py
index 0a9077e66..4087cdc64 100644
--- a/examples/dynamics_simulation/COBA-v2.py
+++ b/examples/dynamics_simulation/COBA-v2.py
@@ -4,7 +4,7 @@
V_initializer=bp.init.Normal(-55., 2.))
-class EICOBA_PreAlign(bp.DynamicalSystemNS):
+class EICOBA_PreAlign(bp.DynamicalSystem):
def __init__(self, num_exc, num_inh, inp=20.):
super().__init__()
@@ -54,7 +54,7 @@ def update(self):
self.I(self.inp)
-class EICOBA_PostAlign(bp.DynamicalSystemNS):
+class EICOBA_PostAlign(bp.DynamicalSystem):
def __init__(self, num_exc, num_inh, inp=20.):
super().__init__()
self.inp = inp
@@ -165,5 +165,5 @@ def run2():
if __name__ == '__main__':
# run1()
- # run2()
- run3()
+ run2()
+ # run3()
diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py
index 60cff2bb1..4818c3ab9 100644
--- a/examples/dynamics_simulation/COBA.py
+++ b/examples/dynamics_simulation/COBA.py
@@ -5,7 +5,7 @@
bm.set_host_device_count(20)
-class EINet(bp.DynamicalSystemNS):
+class EINet(bp.DynamicalSystem):
def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
super().__init__()
@@ -53,7 +53,7 @@ def update(self):
self.delayI(self.I(i_inp))
-class EINetv2(bp.DynamicalSystemNS):
+class EINetv2(bp.DynamicalSystem):
def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
super().__init__()
diff --git a/examples/dynamics_training/echo_state_network.py b/examples/dynamics_training/echo_state_network.py
index b87887d81..0aa816370 100644
--- a/examples/dynamics_training/echo_state_network.py
+++ b/examples/dynamics_training/echo_state_network.py
@@ -6,7 +6,7 @@
bm.set_environment(bm.batching_mode)
-class ESN(bp.DynamicalSystemNS):
+class ESN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden, num_out):
super(ESN, self).__init__()
self.r = bp.layers.Reservoir(num_in,
@@ -25,7 +25,7 @@ def update(self, x):
return x >> self.r >> self.o
-class NGRC(bp.DynamicalSystemNS):
+class NGRC(bp.DynamicalSystem):
def __init__(self, num_in, num_out):
super(NGRC, self).__init__()
diff --git a/examples/dynamics_training/integrate_flax_into_brainpy.py b/examples/dynamics_training/integrate_flax_into_brainpy.py
index 6e5795ca2..107e8b571 100644
--- a/examples/dynamics_training/integrate_flax_into_brainpy.py
+++ b/examples/dynamics_training/integrate_flax_into_brainpy.py
@@ -25,7 +25,7 @@ def __call__(self, x):
return x
-class Network(bp.DynamicalSystemNS):
+class Network(bp.DynamicalSystem):
def __init__(self):
super(Network, self).__init__()
self.cnn = bp.layers.FromFlax(CNN(), bm.ones([1, 4, 28, 1]))
diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py
index ee04b19a4..fc36845e6 100644
--- a/examples/dynamics_training/integrator_rnn.py
+++ b/examples/dynamics_training/integrator_rnn.py
@@ -27,7 +27,7 @@ def train_data():
yield build_inputs_and_targets(batch_size=num_batch)
-class RNN(bp.DynamicalSystemNS):
+class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True)
diff --git a/examples/training_snn_models/spikebased_bp_for_cifar10.py b/examples/training_snn_models/spikebased_bp_for_cifar10.py
index 91e98abb1..384360bba 100644
--- a/examples/training_snn_models/spikebased_bp_for_cifar10.py
+++ b/examples/training_snn_models/spikebased_bp_for_cifar10.py
@@ -41,7 +41,7 @@
help='number of data loading workers (default: 4)')
-class LIFNode(bp.DynamicalSystemNS):
+class LIFNode(bp.DynamicalSystem):
def __init__(self, size, tau=100.0, v_threshold=1.0, v_reset=0.0, fire: bool = True):
super().__init__()
bp.check.is_subclass(self.mode, [bp.math.TrainingMode, bp.math.BatchingMode])
@@ -93,7 +93,7 @@ def update(self, dv):
return self.v.value
-class IFNode(bp.DynamicalSystemNS):
+class IFNode(bp.DynamicalSystem):
def __init__(self, size, v_threshold=0.75, v_reset=0.0):
super().__init__()
bp.check.is_subclass(self.mode, [bm.TrainingMode, bm.BatchingMode])
@@ -121,7 +121,7 @@ def update(self, dv):
return spike
-class ResNet11(bp.DynamicalSystemNS):
+class ResNet11(bp.DynamicalSystem):
def __init__(self):
super().__init__()
diff --git a/tests/simulation/test_net_COBA.py b/tests/simulation/test_net_COBA.py
index 2cf49b402..941f233a0 100644
--- a/tests/simulation/test_net_COBA.py
+++ b/tests/simulation/test_net_COBA.py
@@ -4,7 +4,7 @@
show = False
-class EINet(bp.DynamicalSystemNS):
+class EINet(bp.DynamicalSystem):
def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
super().__init__()
@@ -52,7 +52,7 @@ def update(self):
self.delayI(self.I(i_inp))
-class EINetv2(bp.DynamicalSystemNS):
+class EINetv2(bp.DynamicalSystem):
def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
super().__init__()
diff --git a/tests/simulation/test_neu_HH.py b/tests/simulation/test_neu_HH.py
index 41575ecb1..0990733a4 100644
--- a/tests/simulation/test_neu_HH.py
+++ b/tests/simulation/test_neu_HH.py
@@ -12,7 +12,7 @@ def __init__(self, size):
self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)
-class HHv2(bp.NeuGroupNS):
+class HHv2(bp.NeuDyn):
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, V_th=20., C=1.0):
super().__init__(size=size)
diff --git a/tests/training/test_ESN.py b/tests/training/test_ESN.py
index a7485d40b..df36aa5f3 100644
--- a/tests/training/test_ESN.py
+++ b/tests/training/test_ESN.py
@@ -3,7 +3,7 @@
import unittest
-class ESN(bp.DynamicalSystemNS):
+class ESN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden, num_out):
super(ESN, self).__init__()
self.r = bp.layers.Reservoir(num_in,
@@ -22,7 +22,7 @@ def update(self, x):
return x >> self.r >> self.o
-class NGRC(bp.DynamicalSystemNS):
+class NGRC(bp.DynamicalSystem):
def __init__(self, num_in, num_out):
super(NGRC, self).__init__()
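# --- Editor's illustration (not part of this patch) ---------------------------
# The recurring change in this patch is the rename of the *NS base classes;
# downstream models simply switch the base class, e.g. (a hedged sketch,
# assuming no other constructor changes):
import brainpy as bp

class MyModel(bp.DynamicalSystem):     # was: bp.DynamicalSystemNS
    def __init__(self):
        super().__init__()
# -------------------------------------------------------------------------------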
From 21f0783ada8b34ce866dcd851424f5f4739ed56c Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 14:46:39 +0800
Subject: [PATCH 016/326] fix
---
brainpy/__init__.py | 14 +-
brainpy/_add_deprecations.py | 8 +-
.../highdim/tests/test_slow_points.py | 2 +-
brainpy/_src/dnn/dropout.py | 7 +-
brainpy/_src/dyn/base.py | 24 +++
brainpy/_src/dyn/channels/base.py | 2 +-
brainpy/_src/dyn/channels/tests/test_Ca.py | 15 +-
brainpy/_src/dyn/channels/tests/test_IH.py | 7 +-
brainpy/_src/dyn/channels/tests/test_K.py | 19 ++-
brainpy/_src/dyn/channels/tests/test_KCa.py | 9 +-
brainpy/_src/dyn/channels/tests/test_Na.py | 12 +-
brainpy/_src/dyn/channels/tests/test_leaky.py | 11 +-
brainpy/_src/dyn/ions/base.py | 2 +-
brainpy/_src/dyn/ions/ca.py | 2 +-
brainpy/_src/dyn/neurons/base.py | 2 +-
brainpy/_src/dyn/neurons/hh.py | 3 +-
brainpy/_src/dyn/others/common.py | 2 +-
brainpy/_src/dyn/others/input.py | 13 +-
brainpy/_src/dyn/others/noise.py | 2 +-
.../dyn/others/tests/test_input_groups.py | 6 +-
.../dyn/others/tests/test_noise_groups.py | 3 +-
brainpy/_src/dyn/projections/aligns.py | 12 +-
brainpy/_src/dyn/projections/others.py | 9 +-
brainpy/_src/dyn/rates/populations.py | 2 +-
brainpy/_src/dyn/rates/tests/test_rates.py | 22 ++-
brainpy/_src/dyn/synapses/abstract_models.py | 2 +-
brainpy/_src/dyn/synapses/bio_models.py | 4 +-
brainpy/_src/dyn/synapses/delay_couplings.py | 31 ++--
brainpy/_src/dyn/synapses/gap_junction.py | 3 +-
.../{ => tests}/test_delay_couplings.py | 7 +-
.../synapses/{ => tests}/test_gap_junction.py | 3 +-
.../_src/dynold/neurons/biological_models.py | 32 +++-
.../_src/dynold/neurons/fractional_models.py | 2 +-
brainpy/_src/dynold/neurons/reduced_models.py | 67 ++++++--
.../_src/dynold/synapses/abstract_models.py | 6 +-
brainpy/_src/dynold/synapses/base.py | 6 +-
.../_src/dynold/synapses/biological_models.py | 2 +-
brainpy/_src/dynold/synapses/compat.py | 2 +-
.../_src/dynold/synapses/learning_rules.py | 3 +-
brainpy/_src/dynsys.py | 150 +++++++-----------
brainpy/_src/integrators/ode/exponential.py | 4 +-
.../ode/tests/test_ode_method_exp_euler.py | 2 +-
.../_src/math/event/tests/test_event_csrmv.py | 7 +-
brainpy/_src/math/jitconn/_event_matvec.py | 90 ++++-------
.../math/jitconn/tests/test_event_matvec.py | 111 ++++++-------
.../_src/math/jitconn/tests/test_matvec.py | 8 +-
.../tests/test_circular_reference.py | 2 +-
.../object_transform/tests/test_collector.py | 4 +-
.../tests/test_namechecking.py | 2 +-
brainpy/_src/math/sparse/tests/test_csrmv.py | 8 +-
brainpy/_src/mixin.py | 6 +-
...typing_copy.py => python_typing_copied.py} | 0
brainpy/_src/tests/test_access_methods.py | 2 +-
brainpy/_src/tests/test_dyn_runner.py | 3 +-
brainpy/_src/tests/test_mixin.py | 12 +-
brainpy/_src/tests/test_slice_view.py | 4 -
brainpy/dyn/__init__.py | 1 +
brainpy/dyn/base.py | 7 +
brainpy/dyn/channels.py | 1 +
brainpy/dyn/rates.py | 8 +
brainpy/dyn/synapses.py | 14 ++
brainpy/mixin.py | 1 +
brainpy/rates.py | 2 +
brainpy/synapses.py | 5 +
64 files changed, 466 insertions(+), 368 deletions(-)
create mode 100644 brainpy/_src/dyn/base.py
rename brainpy/_src/dyn/synapses/{ => tests}/test_delay_couplings.py (89%)
rename brainpy/_src/dyn/synapses/{ => tests}/test_gap_junction.py (90%)
rename brainpy/_src/{typing_copy.py => python_typing_copied.py} (100%)
create mode 100644 brainpy/dyn/base.py
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index d3c5f4e3e..efb4af83d 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -59,19 +59,18 @@
DynSysGroup as DynSysGroup, # collectors
Sequential as Sequential,
Network as Network,
- Dynamics as Dynamics, # dynamics
- NeuDyn as NeuDyn,
- SynDyn as SynDyn,
- IonChaDyn as IonChaDyn,
+ Dynamics as Dynamics, # category
+ Projection as Projection,
)
DynamicalSystemNS = DynamicalSystem
-NeuGroup = NeuGroupNS = NeuDyn
+
# building blocks
from brainpy import (
dnn, layers, # module for dnn layers
dyn, # module for modeling dynamics
)
+NeuGroup = NeuGroupNS = dyn.NeuDyn
# shared parameters
from brainpy._src.context import (share as share)
@@ -131,6 +130,11 @@
'Container': ('brainpy.Container', 'brainpy.DynSysGroup', DynSysGroup),
'optimizers': ('brainpy.optimizers', 'brainpy.optim', optim),
'TensorCollector': ('brainpy.TensorCollector', 'brainpy.ArrayCollector', ArrayCollector),
+ 'SynSTP': ('brainpy.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
+ 'SynOut': ('brainpy.SynOut', 'brainpy.synapses.SynOut', synapses.SynOut),
+ 'SynConn': ('brainpy.SynConn', 'brainpy.synapses.SynConn', synapses.SynConn),
+ 'TwoEndConn': ('brainpy.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
+ 'CondNeuGroup': ('brainpy.CondNeuGroup', 'brainpy.syn.CondNeuGroup', dyn.CondNeuGroup),
}
__getattr__ = deprecation_getattr2('brainpy', __deprecations)
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
index f2f387cff..b7a477ae3 100644
--- a/brainpy/_add_deprecations.py
+++ b/brainpy/_add_deprecations.py
@@ -8,8 +8,8 @@
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.integrators.sde.generic import sdeint
from brainpy._src.integrators.fde.generic import fdeint
-from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network,
- NeuDyn, Projection, IonChaDyn)
+from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network)
+from brainpy._src.dyn.base import NeuDyn, IonChaDyn
from brainpy._src.runners import DSRunner
from brainpy._src.deprecations import deprecation_getattr2
@@ -55,6 +55,8 @@
synapses.__deprecations = {
'PoissonInput': ('brainpy.synapses.PoissonInput', 'brainpy.dyn.PoissonInput', dyn.PoissonInput),
+ 'DiffusiveCoupling': ('brainpy.synapses.DiffusiveCoupling', 'brainpy.dyn.DiffusiveCoupling', dyn.DiffusiveCoupling),
+ 'AdditiveCoupling': ('brainpy.synapses.AdditiveCoupling', 'brainpy.dyn.AdditiveCoupling', dyn.AdditiveCoupling),
}
synapses.__getattr__ = deprecation_getattr2('brainpy.synapses', synapses.__deprecations)
@@ -87,7 +89,7 @@
# synapses
'SynConn': ('brainpy.dyn.SynConn', 'brainpy.synapses.SynConn', synapses.SynConn),
# 'SynLTP': ('brainpy.dyn.SynLTP', 'brainpy.synapses.SynLTP', synapses.SynLTP),
- 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses._SynSTP),
+ 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
'DeltaSynapse': ('brainpy.dyn.DeltaSynapse', 'brainpy.synapses.Delta', synapses.DeltaSynapse),
'ExpCUBA': ('brainpy.dyn.ExpCUBA', 'brainpy.synapses.Exponential', synapses.ExpCUBA),
diff --git a/brainpy/_src/analysis/highdim/tests/test_slow_points.py b/brainpy/_src/analysis/highdim/tests/test_slow_points.py
index f4151cb85..9cf8f4fa8 100644
--- a/brainpy/_src/analysis/highdim/tests/test_slow_points.py
+++ b/brainpy/_src/analysis/highdim/tests/test_slow_points.py
@@ -5,7 +5,7 @@
import brainpy.math as bm
-class HH(bp.NeuDyn):
+class HH(bp.dyn.NeuDyn):
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03,
V_th=20., C=1.0, name=None):
super(HH, self).__init__(size=size, name=name)
diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py
index 80dbafdd4..dd60cc1df 100644
--- a/brainpy/_src/dnn/dropout.py
+++ b/brainpy/_src/dnn/dropout.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
+from typing import Optional
from brainpy._src.context import share
from brainpy import math as bm, check
@@ -16,7 +17,7 @@ class Dropout(Layer):
In training, to compensate for the fraction of input values dropped (`rate`),
all surviving values are multiplied by `1 / (1 - rate)`.
- This layer is active only during training (`mode=brainpy.modes.training`). In other
+ This layer is active only during training (``mode=brainpy.math.training_mode``). In other
circumstances it is a no-op.
.. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
@@ -33,8 +34,8 @@ class Dropout(Layer):
def __init__(
self,
prob: float,
- mode: bm.Mode = None,
- name: str = None
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None
):
super(Dropout, self).__init__(mode=mode, name=name)
self.prob = check.is_float(prob, min_bound=0., max_bound=1.)
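
The docstring and typing fixes above do not change the semantics: dropout acts only under training mode, drops each element with probability `rate`, and rescales survivors by 1 / (1 - rate) so the expected activation is unchanged. A minimal NumPy sketch of that behavior (illustrative only, not brainpy's internal code):

    import numpy as np

    def dropout_sketch(x, rate, training, rng=np.random):
        # Inverted dropout: drop each element with probability `rate` during
        # training and scale survivors by 1 / (1 - rate); no-op otherwise.
        if not training:
            return x
        keep = rng.random(x.shape) >= rate
        return np.where(keep, x / (1. - rate), 0.)
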
diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
new file mode 100644
index 000000000..c37504d47
--- /dev/null
+++ b/brainpy/_src/dyn/base.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+
+from brainpy._src.dynsys import Dynamics
+from brainpy._src.mixin import AutoDelaySupp, ParamDesc
+
+__all__ = [
+ 'NeuDyn', 'SynDyn', 'IonChaDyn',
+]
+
+
+class NeuDyn(Dynamics, AutoDelaySupp):
+ """Neuronal Dynamics."""
+ pass
+
+
+class SynDyn(Dynamics, AutoDelaySupp, ParamDesc):
+ """Synaptic Dynamics."""
+ pass
+
+
+class IonChaDyn(Dynamics):
+ """Ion Channel Dynamics."""
+ pass
+
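
This new module re-homes the lightweight base classes: NeuDyn is Dynamics plus automatic delay support, SynDyn additionally carries ParamDesc so it can serve as a descriptor in projections, and IonChaDyn is plain Dynamics. A minimal hedged example of subclassing the relocated NeuDyn through the public bp.dyn namespace (dynamics and names are illustrative):

    import brainpy as bp
    import brainpy.math as bm

    class DemoNeuDyn(bp.dyn.NeuDyn):
        def __init__(self, size):
            super().__init__(size=size)
            self.V = bm.Variable(bm.zeros(self.num))   # membrane potential

        def update(self, x=0.):
            self.V.value = self.V + x                  # trivial dynamics for illustration
            return self.V.value
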
diff --git a/brainpy/_src/dyn/channels/base.py b/brainpy/_src/dyn/channels/base.py
index db2d9700d..863bbd7d4 100644
--- a/brainpy/_src/dyn/channels/base.py
+++ b/brainpy/_src/dyn/channels/base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-from brainpy._src.dynsys import IonChaDyn
+from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.mixin import TreeNode
from brainpy._src.dyn.ions.base import Calcium
from brainpy._src.dyn.neurons.hh import HHTypedNeuron
diff --git a/brainpy/_src/dyn/channels/tests/test_Ca.py b/brainpy/_src/dyn/channels/tests/test_Ca.py
index 2ffe1a983..0b7593f7b 100644
--- a/brainpy/_src/dyn/channels/tests/test_Ca.py
+++ b/brainpy/_src/dyn/channels/tests/test_Ca.py
@@ -4,12 +4,11 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.dyn.channels import Ca
class Test_Ca(parameterized.TestCase):
def test_Ca(self):
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca1 = bp.dyn.CalciumFixed(size)
@@ -29,7 +28,7 @@ def __init__(self, size):
def test_ICaN_IS2008(self):
bm.random.seed(1234)
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca = bp.dyn.CalciumDetailed(size,
@@ -47,7 +46,7 @@ def __init__(self, size):
def test_ICaT_HM1992(self):
bm.random.seed(1234)
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca = bp.dyn.CalciumDetailed(size,
@@ -67,7 +66,7 @@ def __init__(self, size):
def test_ICaT_HP1992(self):
bm.random.seed(1234)
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca = bp.dyn.CalciumDetailed(size,
@@ -87,7 +86,7 @@ def __init__(self, size):
def test_ICaHT_HM1992(self):
bm.random.seed(1234)
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca = bp.dyn.CalciumDetailed(size,
@@ -107,7 +106,7 @@ def __init__(self, size):
def test_ICaHT_Re1993(self):
bm.random.seed(1234)
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca = bp.dyn.CalciumDetailed(size,
@@ -127,7 +126,7 @@ def __init__(self, size):
def test_ICaL_IS2008(self):
bm.random.seed(1234)
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
self.Ca = bp.dyn.CalciumDetailed(size,
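
These channel tests now assemble conductance-based neurons entirely from the public bp.dyn namespace: a CondNeuGroup subclass declares ion pools and channels as attributes, and DSRunner integrates the whole container. A minimal hedged example of the pattern (the channel choice is illustrative):

    import brainpy as bp
    import brainpy.math as bm

    class DemoCondNeuron(bp.dyn.CondNeuGroup):
        def __init__(self, size):
            super().__init__(size)
            self.Ca = bp.dyn.CalciumFixed(size)   # fixed-concentration calcium pool

    bm.random.seed(1234)
    runner = bp.DSRunner(DemoCondNeuron(1), monitors=['V'], progress_bar=False)
    runner.run(10.)
    print(runner.mon['V'].shape)   # (100, 1) with the default dt of 0.1 ms
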
diff --git a/brainpy/_src/dyn/channels/tests/test_IH.py b/brainpy/_src/dyn/channels/tests/test_IH.py
index f4e589a0d..5860a9cdd 100644
--- a/brainpy/_src/dyn/channels/tests/test_IH.py
+++ b/brainpy/_src/dyn/channels/tests/test_IH.py
@@ -4,17 +4,16 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.dyn.channels import IH, Ca
class Test_IH(parameterized.TestCase):
bm.random.seed(1234)
def test_IH(self):
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size)
- self.IH = IH.Ih_HM1992(size)
- self.Ca = Ca.CalciumDetailed(size, IH=IH.Ih_De1996(size))
+ self.IH = bp.dyn.Ih_HM1992(size)
+ self.Ca = bp.dyn.CalciumDetailed(size, IH=bp.dyn.Ih_De1996(size))
model = Neuron(1)
runner = bp.DSRunner(model,
diff --git a/brainpy/_src/dyn/channels/tests/test_K.py b/brainpy/_src/dyn/channels/tests/test_K.py
index 1fc625b90..2bdd63bde 100644
--- a/brainpy/_src/dyn/channels/tests/test_K.py
+++ b/brainpy/_src/dyn/channels/tests/test_K.py
@@ -4,22 +4,21 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.dyn.channels import K
class Test_K(parameterized.TestCase):
bm.random.seed(1234)
def test_K(self):
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
- self.IK_1 = K.IKDR_Ba2002(size)
- self.IK_2 = K.IK_TM1991(size)
- self.IK_3 = K.IK_HH1952(size)
- self.IK_4 = K.IKA1_HM1992(size)
- self.IK_5 = K.IKA2_HM1992(size)
- self.IK_6 = K.IKK2A_HM1992(size)
- self.IK_7 = K.IKK2B_HM1992(size)
- self.IK_8 = K.IKNI_Ya1989(size)
+ self.IK_1 = bp.dyn.IKDR_Ba2002(size)
+ self.IK_2 = bp.dyn.IK_TM1991(size)
+ self.IK_3 = bp.dyn.IK_HH1952(size)
+ self.IK_4 = bp.dyn.IKA1_HM1992(size)
+ self.IK_5 = bp.dyn.IKA2_HM1992(size)
+ self.IK_6 = bp.dyn.IKK2A_HM1992(size)
+ self.IK_7 = bp.dyn.IKK2B_HM1992(size)
+ self.IK_8 = bp.dyn.IKNI_Ya1989(size)
model = Neuron(1)
runner = bp.DSRunner(model,
diff --git a/brainpy/_src/dyn/channels/tests/test_KCa.py b/brainpy/_src/dyn/channels/tests/test_KCa.py
index d422dc28a..ad52c0871 100644
--- a/brainpy/_src/dyn/channels/tests/test_KCa.py
+++ b/brainpy/_src/dyn/channels/tests/test_KCa.py
@@ -4,15 +4,16 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.dyn.channels import KCa, Ca
+
class Test_KCa(parameterized.TestCase):
bm.random.seed(1234)
+
def test_KCa(self):
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
- self.Ca = Ca.CalciumDetailed(size, KCa=KCa.IAHP_De1994(size))
+ self.Ca = bp.dyn.CalciumDetailed(size, KCa=bp.dyn.IAHP_De1994(size))
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -20,4 +21,4 @@ def __init__(self, size):
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
- self.assertTupleEqual(runner.mon['Ca.KCa.p'].shape, (100, 1))
\ No newline at end of file
+ self.assertTupleEqual(runner.mon['Ca.KCa.p'].shape, (100, 1))
diff --git a/brainpy/_src/dyn/channels/tests/test_Na.py b/brainpy/_src/dyn/channels/tests/test_Na.py
index f2112162f..58002e3f0 100644
--- a/brainpy/_src/dyn/channels/tests/test_Na.py
+++ b/brainpy/_src/dyn/channels/tests/test_Na.py
@@ -4,18 +4,18 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.dyn.channels import Na
class Test_Na(parameterized.TestCase):
bm.random.seed(1234)
+
def test_Na(self):
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
- self.INa_1 = Na.INa_HH1952(size, E=50., g_max=120.)
- self.INa_2 = Na.INa_TM1991(size)
- self.INa_3 = Na.INa_Ba2002(size)
+ self.INa_1 = bp.dyn.INa_HH1952(size, E=50., g_max=120.)
+ self.INa_2 = bp.dyn.INa_TM1991(size)
+ self.INa_3 = bp.dyn.INa_Ba2002(size)
model = Neuron(1)
runner = bp.DSRunner(model,
@@ -29,5 +29,3 @@ def __init__(self, size):
self.assertTupleEqual(runner.mon['INa_2.q'].shape, (100, 1))
self.assertTupleEqual(runner.mon['INa_3.p'].shape, (100, 1))
self.assertTupleEqual(runner.mon['INa_3.q'].shape, (100, 1))
-
-
diff --git a/brainpy/_src/dyn/channels/tests/test_leaky.py b/brainpy/_src/dyn/channels/tests/test_leaky.py
index 341e7c213..9535cefde 100644
--- a/brainpy/_src/dyn/channels/tests/test_leaky.py
+++ b/brainpy/_src/dyn/channels/tests/test_leaky.py
@@ -4,20 +4,21 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.dyn.channels import leaky
+
class Test_Leaky(parameterized.TestCase):
bm.random.seed(1234)
+
def test_leaky(self):
- class Neuron(bp.CondNeuGroup):
+ class Neuron(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
- self.leaky1 = leaky.IL(size)
- self.leaky2 = leaky.IKL(size)
+ self.leaky1 = bp.dyn.IL(size)
+ self.leaky2 = bp.dyn.IKL(size)
model = Neuron(1)
runner = bp.DSRunner(model,
monitors=['V'],
progress_bar=False)
runner.run(10.)
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
\ No newline at end of file
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1))
diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py
index 2b260c03c..bee8c08c2 100644
--- a/brainpy/_src/dyn/ions/base.py
+++ b/brainpy/_src/dyn/ions/base.py
@@ -4,7 +4,7 @@
import brainpy.math as bm
from brainpy._src.dyn.neurons.hh import CondNeuGroup
-from brainpy._src.dynsys import IonChaDyn
+from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.mixin import Container, TreeNode
from brainpy.types import Shape
diff --git a/brainpy/_src/dyn/ions/ca.py b/brainpy/_src/dyn/ions/ca.py
index 29a5b8a2e..89bc2d2d1 100644
--- a/brainpy/_src/dyn/ions/ca.py
+++ b/brainpy/_src/dyn/ions/ca.py
@@ -4,7 +4,7 @@
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dynsys import IonChaDyn
+from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.initialize import OneInit, Initializer, parameter, variable
from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import Shape, ArrayType
diff --git a/brainpy/_src/dyn/neurons/base.py b/brainpy/_src/dyn/neurons/base.py
index bfe75c155..de4317a83 100644
--- a/brainpy/_src/dyn/neurons/base.py
+++ b/brainpy/_src/dyn/neurons/base.py
@@ -2,7 +2,7 @@
import brainpy.math as bm
from brainpy._src.dyn._docs import pneu_doc, dpneu_doc
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy.check import is_callable
__all__ = ['GradNeuDyn']
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index cbfeb69fa..482a3ac91 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -4,7 +4,8 @@
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dynsys import NeuDyn, IonChaDyn, DynamicalSystem
+from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.dyn.base import NeuDyn, IonChaDyn
from brainpy._src.initialize import OneInit
from brainpy._src.initialize import Uniform, variable_, noise as init_noise
from brainpy._src.integrators import JointEq
diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py
index 418cb6ad1..ef069d4ea 100644
--- a/brainpy/_src/dyn/others/common.py
+++ b/brainpy/_src/dyn/others/common.py
@@ -5,7 +5,7 @@
from brainpy._src import tools
from brainpy._src.context import share
from brainpy._src.dyn._docs import pneu_doc
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.integrators import odeint
from brainpy.check import is_initializer
from brainpy.types import ArrayType
diff --git a/brainpy/_src/dyn/others/input.py b/brainpy/_src/dyn/others/input.py
index 041f8b59f..0bf8a2b76 100644
--- a/brainpy/_src/dyn/others/input.py
+++ b/brainpy/_src/dyn/others/input.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-
+import warnings
from functools import partial
from typing import Union, Sequence, Any, Optional, Callable
@@ -9,7 +9,7 @@
from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.dyn.utils import get_spk_type
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import ReturnInfo
from brainpy.types import Shape, ArrayType
@@ -165,7 +165,8 @@ def reset_state(self, batch_size=None):
batch_axis_name=bm.sharding.BATCH_AXIS)
def update(self):
- self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding)
+ # self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding)
+ self.spike.value = bm.zeros_like(self.spike)
bm.while_loop(self._body_fun, self._cond_fun, ())
return self.spike.value
@@ -199,6 +200,7 @@ def __init__(
spk_type: Optional[type] = None,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
+ seed=None,
):
super(PoissonGroup, self).__init__(size=size,
sharding=sharding,
@@ -206,6 +208,9 @@ def __init__(
keep_size=keep_size,
mode=mode)
+ if seed is not None:
+   warnings.warn('The "seed" argument is deprecated and has no effect; set the global seed with brainpy.math.random.seed() instead.', DeprecationWarning)
+
# parameters
self.freqs = parameter(freqs, self.num, allow_none=False)
self.spk_type = get_spk_type(spk_type, self.mode)
@@ -216,7 +221,7 @@ def __init__(
def update(self):
spikes = bm.random.rand_like(self.spike) <= (self.freqs * share.dt / 1000.)
spikes = bm.asarray(spikes, dtype=self.spk_type)
- spikes = bm.sharding.partition(spikes, self.spike.sharding)
+ # spikes = bm.sharding.partition(spikes, self.spike.sharding)
self.spike.value = spikes
return spikes
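
The updated PoissonGroup draws spikes with a per-step Bernoulli rule and no longer re-partitions the result across devices. With rate `freqs` in Hz and step `dt` in ms, each unit fires with probability freqs * dt / 1000, which approximates a Poisson process when freqs * dt is much smaller than 1000. A NumPy sketch of that rule (illustrative only):

    import numpy as np

    def poisson_step(freqs_hz, dt_ms, size, rng=np.random):
        # Bernoulli approximation of a Poisson process for one time step.
        p = freqs_hz * dt_ms / 1000.
        return rng.random(size) <= p   # boolean spike vector
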
diff --git a/brainpy/_src/dyn/others/noise.py b/brainpy/_src/dyn/others/noise.py
index 255d3f1f1..50db2f4dd 100644
--- a/brainpy/_src/dyn/others/noise.py
+++ b/brainpy/_src/dyn/others/noise.py
@@ -4,7 +4,7 @@
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import variable_, parameter
from brainpy._src.integrators.sde.generic import sdeint
from brainpy.types import Shape, ArrayType
diff --git a/brainpy/_src/dyn/others/tests/test_input_groups.py b/brainpy/_src/dyn/others/tests/test_input_groups.py
index 1028bcc8e..352babde3 100644
--- a/brainpy/_src/dyn/others/tests/test_input_groups.py
+++ b/brainpy/_src/dyn/others/tests/test_input_groups.py
@@ -3,13 +3,13 @@
import brainpy as bp
from absl.testing import parameterized
-from brainpy._src.neurons import input_groups
+from brainpy._src.dyn.others import input
class Test_input_Group(parameterized.TestCase):
def test_SpikeTimeGroup(self):
bp.math.random.seed()
- model = input_groups.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
+ model = input.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])
runner = bp.DSRunner(model,
monitors=['spike'],
progress_bar=False)
@@ -19,7 +19,7 @@ def test_SpikeTimeGroup(self):
def test_PoissonGroup(self):
bp.math.random.seed()
- model = input_groups.PoissonGroup(size=2, freqs=1000, seed=0)
+ model = input.PoissonGroup(size=2, freqs=1000)
runner = bp.DSRunner(model,
monitors=['spike'],
progress_bar=False)
diff --git a/brainpy/_src/dyn/others/tests/test_noise_groups.py b/brainpy/_src/dyn/others/tests/test_noise_groups.py
index 2fc831e61..d93657c89 100644
--- a/brainpy/_src/dyn/others/tests/test_noise_groups.py
+++ b/brainpy/_src/dyn/others/tests/test_noise_groups.py
@@ -4,13 +4,12 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
-from brainpy._src.neurons import noise_groups
class Test_Noise_Group(parameterized.TestCase):
def test_OU(self):
bm.random.seed(1234)
- model = noise_groups.OUProcess(size=1, mean=0., sigma=0.1)
+ model = bp.dyn.OUProcess(size=1, mean=0., sigma=0.1)
runner = bp.DSRunner(model,
monitors=['x'],
progress_bar=False)
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index 7ad9535c9..7d0f7395b 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -2,8 +2,8 @@
from brainpy import math as bm
from brainpy._src.delay import Delay, VariableDelay, DataDelay
-from brainpy._src.dynsys import DynamicalSystem, Projection, NeuDyn
-from brainpy._src.mixin import JointType, ParamDesc, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
+from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamics
+from brainpy._src.mixin import JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
__all__ = [
'ProjAlignPre',
@@ -81,7 +81,7 @@ def __init__(
delay: Union[None, int, float],
comm: Callable,
out: JointType[DynamicalSystem, BindCondData],
- post: NeuDyn,
+ post: Dynamics,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
@@ -92,7 +92,7 @@ def __init__(
assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
assert isinstance(comm, Callable)
assert isinstance(out, JointType[DynamicalSystem, BindCondData])
- assert isinstance(post, NeuDyn)
+ assert isinstance(post, Dynamics)
self.pre = pre
self.post = post
self.comm = comm
@@ -140,7 +140,7 @@ def __init__(
comm: Callable,
syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
- post: NeuDyn,
+ post: Dynamics,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
@@ -151,7 +151,7 @@ def __init__(
assert isinstance(comm, Callable)
assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
assert isinstance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
- assert isinstance(post, NeuDyn)
+ assert isinstance(post, Dynamics)
self.pre = pre
self.post = post
self.comm = comm
diff --git a/brainpy/_src/dyn/projections/others.py b/brainpy/_src/dyn/projections/others.py
index 506382e2e..44cdfb043 100644
--- a/brainpy/_src/dyn/projections/others.py
+++ b/brainpy/_src/dyn/projections/others.py
@@ -1,4 +1,5 @@
import numbers
+import warnings
from typing import Union, Optional
from brainpy import check, math as bm
@@ -37,10 +38,14 @@ def __init__(
freq: Union[int, float],
weight: Union[int, float],
mode: Optional[bm.Mode] = None,
- name: Optional[str] = None
+ name: Optional[str] = None,
+ seed=None
):
super().__init__(name=name, mode=mode)
+ if seed is not None:
+   warnings.warn('The "seed" argument is deprecated and has no effect; set the global seed with brainpy.math.random.seed() instead.', DeprecationWarning)
+
if not isinstance(target_var, bm.Variable):
raise TypeError(f'"target_var" must be an instance of Variable. '
f'But we got {type(target_var)}: {target_var}')
@@ -66,7 +71,7 @@ def update(self):
lambda: bm.random.binomial(self.num_input, p, self.target_var.shape),
())
- inp = bm.sharding.partition(inp, self.target_var.sharding)
+ # inp = bm.sharding.partition(inp, self.target_var.sharding)
self.target_var += inp * self.weight
def __repr__(self):
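
PoissonInput keeps the same sampling logic while dropping the sharding partition: the summed spike count of num_input independent sources, each active with probability p = freq * dt / 1000 per step, is a Binomial(num_input, p) draw, and the target variable is incremented by count * weight. A NumPy sketch of one step (illustrative, not the in-tree code):

    import numpy as np

    def poisson_input_step(num_input, freq_hz, dt_ms, weight, shape, rng=np.random):
        p = freq_hz * dt_ms / 1000.
        counts = rng.binomial(num_input, p, shape)   # spikes arriving this step
        return counts * weight                       # increment for the target variable
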
diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py
index afea3c4b2..9ce83e144 100644
--- a/brainpy/_src/dyn/rates/populations.py
+++ b/brainpy/_src/dyn/rates/populations.py
@@ -7,7 +7,7 @@
from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.dyn.others.noise import OUProcess
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import (Initializer,
Uniform,
parameter,
diff --git a/brainpy/_src/dyn/rates/tests/test_rates.py b/brainpy/_src/dyn/rates/tests/test_rates.py
index 88c016705..4ae162b8f 100644
--- a/brainpy/_src/dyn/rates/tests/test_rates.py
+++ b/brainpy/_src/dyn/rates/tests/test_rates.py
@@ -2,6 +2,7 @@
import brainpy as bp
+import brainpy.math as bm
from absl.testing import parameterized
from brainpy._src.dyn.rates import populations
from unittest import TestCase
@@ -9,26 +10,32 @@
class TestRate(TestCase):
def test_fhn(self):
+ bm.random.seed()
fhn = bp.rates.FHN(10)
self.assertTrue(fhn.tau is not None)
def test_ffhn(self):
+ bm.random.seed()
ffhn = bp.rates.FeedbackFHN(size=1)
self.assertTrue(ffhn.tau is not None)
def test_qif(self):
+ bm.random.seed()
qif = bp.rates.QIF(size=1)
self.assertTrue(qif.tau is not None)
def test_slo(self):
+ bm.random.seed()
slo = bp.rates.StuartLandauOscillator(size=1)
self.assertTrue(slo.x_ou_tau is not None)
def test_wcm(self):
+ bm.random.seed()
wcm = bp.rates.WilsonCowanModel(size=1)
self.assertTrue(wcm.x_ou_tau is not None)
def test_tlm(self):
+ bm.random.seed()
tlm = bp.rates.ThresholdLinearModel(size=1)
self.assertTrue(tlm.tau_e is not None)
@@ -39,47 +46,60 @@ class TestPopulation(parameterized.TestCase):
for name in populations.__all__
)
def test_runner(self, neuron):
+ bm.random.seed()
model = getattr(populations, neuron)(size=10)
runner = bp.DSRunner(model, progress_bar=False)
runner.run(10.)
+ bm.clear_buffer_memory()
class TestShape(parameterized.TestCase):
def test_FHN_shape(self):
+ bm.random.seed()
model = getattr(populations, 'FHN')(size=10)
runner = bp.DSRunner(model,
monitors=['x'],
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon.x.shape, (100, 10))
+ bm.clear_buffer_memory()
def test_FFHN_shape(self):
+ bm.random.seed()
model = getattr(populations, 'FeedbackFHN')(size=10)
runner = bp.DSRunner(model,
monitors=['x'],
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon.x.shape, (100, 10))
+ bm.clear_buffer_memory()
def test_QIF_shape(self):
+ bm.random.seed()
model = getattr(populations, 'QIF')(size=10)
runner = bp.DSRunner(model,
monitors=['x'],
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon.x.shape, (100, 10))
+ bm.clear_buffer_memory()
def test_SLO_shape(self):
+ bm.random.seed()
model = getattr(populations, 'StuartLandauOscillator')(size=10)
runner = bp.DSRunner(model,
monitors=['x'],
progress_bar=False)
runner.run(10.)
self.assertTupleEqual(runner.mon.x.shape, (100, 10))
+ bm.clear_buffer_memory()
def test_TLM_shape(self):
+ bm.random.seed()
model = getattr(populations, 'ThresholdLinearModel')(size=10)
runner = bp.DSRunner(model,
monitors=['e'],
progress_bar=False)
runner.run(10.)
- self.assertTupleEqual(runner.mon.e.shape, (100, 10))
\ No newline at end of file
+ self.assertTupleEqual(runner.mon.e.shape, (100, 10))
+ bm.clear_buffer_memory()
+
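
The rate-model tests now follow a common hygiene pattern: seed the global RNG before building each model and release JAX device buffers afterwards so parameterized cases do not accumulate memory. A hedged sketch of that pattern as a reusable helper (names illustrative):

    import brainpy as bp
    import brainpy.math as bm

    def run_once(model_cls, monitor, size=10, duration=10.):
        bm.random.seed()
        runner = bp.DSRunner(model_cls(size=size), monitors=[monitor], progress_bar=False)
        runner.run(duration)
        shape = runner.mon[monitor].shape    # (n_steps, size)
        bm.clear_buffer_memory()
        return shape
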
diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index 421cc086c..cd8162f58 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -4,7 +4,7 @@
from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.dyn._docs import pneu_doc
-from brainpy._src.dynsys import SynDyn
+from brainpy._src.dyn.base import SynDyn
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.mixin import AlignPost, ReturnInfo
diff --git a/brainpy/_src/dyn/synapses/bio_models.py b/brainpy/_src/dyn/synapses/bio_models.py
index fd182380a..5e1866a66 100644
--- a/brainpy/_src/dyn/synapses/bio_models.py
+++ b/brainpy/_src/dyn/synapses/bio_models.py
@@ -4,11 +4,9 @@
from brainpy import math as bm
from brainpy._src.context import share
from brainpy._src.dyn._docs import pneu_doc
-from brainpy._src.dynsys import SynDyn
+from brainpy._src.dyn.base import SynDyn
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
-from brainpy._src.mixin import AlignPost, ReturnInfo
-from brainpy._src.initialize import Constant
from brainpy.types import ArrayType
__all__ = [
diff --git a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py
index 4ce50c3ee..a4ecaa67c 100644
--- a/brainpy/_src/dyn/synapses/delay_couplings.py
+++ b/brainpy/_src/dyn/synapses/delay_couplings.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
+import numbers
from typing import Optional, Union, Sequence, Tuple, Callable
import jax.numpy as jnp
from jax import vmap
import brainpy.math as bm
-from brainpy._src.dynsys import DynSysGroup as SynConn
-from brainpy._src.neurons.input_groups import InputGroup, OutputGroup
+from brainpy._src.dynsys import Projection
from brainpy._src.initialize import Initializer
from brainpy.check import is_sequence
from brainpy.types import ArrayType
@@ -19,7 +19,7 @@
]
-class DelayCoupling(SynConn):
+class DelayCoupling(Projection):
"""Delay coupling.
Parameters
@@ -44,15 +44,12 @@ def __init__(
var_to_output: Union[bm.Variable, Sequence[bm.Variable]],
conn_mat: ArrayType,
required_shape: Tuple[int, ...],
- delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None,
- initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None,
- name: str = None,
- mode: bm.Mode = None,
+ delay_steps: Optional[Union[int, ArrayType, Callable]] = None,
+ initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
- super(DelayCoupling, self).__init__(name=name,
- mode=mode,
- pre=InputGroup(1),
- post=OutputGroup(1))
+ super().__init__(name=name, mode=mode)
# delay variable
if not isinstance(delay_var, bm.Variable):
@@ -177,7 +174,7 @@ def __init__(
raise ValueError(f'Only support 1d vector of coupling variable. '
f'But we got {jnp.ndim(coupling_var2)}')
- super(DiffusiveCoupling, self).__init__(
+ super().__init__(
delay_var=coupling_var1,
var_to_output=var_to_output,
conn_mat=conn_mat,
@@ -191,10 +188,10 @@ def __init__(
self.coupling_var1 = coupling_var1
self.coupling_var2 = coupling_var2
- def update(self, tdi):
+ def update(self):
# delays
axis = self.coupling_var1.ndim
- delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0]
+ delay_var: bm.LengthDelay = self.get_delay_var(f'delay_{id(self.delay_var)}')[0]
if self.delay_steps is None:
diffusive = (jnp.expand_dims(self.coupling_var1.value, axis=axis) -
jnp.expand_dims(self.coupling_var2.value, axis=axis - 1))
@@ -263,7 +260,7 @@ def __init__(
raise ValueError(f'Only support 1d vector of coupling variable. '
f'But we got {jnp.ndim(coupling_var)}')
- super(AdditiveCoupling, self).__init__(
+ super().__init__(
delay_var=coupling_var,
var_to_output=var_to_output,
conn_mat=conn_mat,
@@ -276,10 +273,10 @@ def __init__(
self.coupling_var = coupling_var
- def update(self, tdi):
+ def update(self):
# delay function
axis = self.coupling_var.ndim
- delay_var: bm.LengthDelay = self.global_delay_data[f'delay_{id(self.delay_var)}'][0]
+ delay_var: bm.LengthDelay = self.get_delay_var(f'delay_{id(self.delay_var)}')[0]
if self.delay_steps is None:
additive = self.coupling_var @ self.conn_mat
elif self.delay_type == 'array':
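
Beyond re-parenting DelayCoupling onto Projection and fetching delays through get_delay_var, the coupling math is unchanged: with no delay, diffusive coupling feeds each output i the connection-weighted differences sum_j conn[j, i] * (x1[j] - x2[i]), while additive coupling feeds it sum_j conn[j, i] * x[j]. A NumPy sketch of the zero-delay terms (illustrative stand-in for the JAX code above):

    import numpy as np

    def diffusive_coupling(x1, x2, conn):
        # conn has shape (pre, post); output i gets sum_j conn[j, i] * (x1[j] - x2[i])
        diff = x1[:, None] - x2[None, :]
        return np.sum(conn * diff, axis=0)

    def additive_coupling(x, conn):
        # output i gets sum_j conn[j, i] * x[j]
        return x @ conn
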
diff --git a/brainpy/_src/dyn/synapses/gap_junction.py b/brainpy/_src/dyn/synapses/gap_junction.py
index c9432d3b0..c37903fc5 100644
--- a/brainpy/_src/dyn/synapses/gap_junction.py
+++ b/brainpy/_src/dyn/synapses/gap_junction.py
@@ -3,8 +3,9 @@
from typing import Union, Dict, Callable
import brainpy.math as bm
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.connect import TwoEndConnector
-from brainpy._src.dynsys import NeuDyn, DynamicalSystem as TwoEndConn
+from brainpy._src.dynold.synapses import TwoEndConn
from brainpy._src.initialize import Initializer, parameter
from brainpy.types import ArrayType
diff --git a/brainpy/_src/dyn/synapses/test_delay_couplings.py b/brainpy/_src/dyn/synapses/tests/test_delay_couplings.py
similarity index 89%
rename from brainpy/_src/dyn/synapses/test_delay_couplings.py
rename to brainpy/_src/dyn/synapses/tests/test_delay_couplings.py
index 51af9d685..f6099abbd 100644
--- a/brainpy/_src/dyn/synapses/test_delay_couplings.py
+++ b/brainpy/_src/dyn/synapses/tests/test_delay_couplings.py
@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
+from absl.testing import parameterized
+
import brainpy as bp
import brainpy.math as bm
-from brainpy import rates
-from absl.testing import parameterized
-from brainpy._src.synapses import delay_couplings
class Test_delay_couplings(parameterized.TestCase):
@@ -14,7 +13,7 @@ def test_DiffusiveCoupling(self):
areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn1')
conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input,
conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require('conn_mat'),
- initial_delay_data = bp.init.Uniform(0, 0.05))
+ initial_delay_data=bp.init.Uniform(0, 0.05))
net = bp.Network(areas, conn)
# Run the simulation
diff --git a/brainpy/_src/dyn/synapses/test_gap_junction.py b/brainpy/_src/dyn/synapses/tests/test_gap_junction.py
similarity index 90%
rename from brainpy/_src/dyn/synapses/test_gap_junction.py
rename to brainpy/_src/dyn/synapses/tests/test_gap_junction.py
index c3ff9440b..8ef37459a 100644
--- a/brainpy/_src/dyn/synapses/test_gap_junction.py
+++ b/brainpy/_src/dyn/synapses/tests/test_gap_junction.py
@@ -3,9 +3,8 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy import rates
from absl.testing import parameterized
-from brainpy._src.synapses import gap_junction
+from brainpy._src.dyn.synapses import gap_junction
class Test_gap_junction(parameterized.TestCase):
diff --git a/brainpy/_src/dynold/neurons/biological_models.py b/brainpy/_src/dynold/neurons/biological_models.py
index 2adad502c..0ea235296 100644
--- a/brainpy/_src/dynold/neurons/biological_models.py
+++ b/brainpy/_src/dynold/neurons/biological_models.py
@@ -6,7 +6,7 @@
from brainpy import check
from brainpy._src.context import share
from brainpy._src.dyn.neurons import hh
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import (OneInit,
Initializer,
parameter,
@@ -198,10 +198,18 @@ class HH(hh.HH):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+
+ self.noise = init_noise(noise, self.varshape, num_vars=4)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -298,10 +306,17 @@ class MorrisLecar(hh.MorrisLecar):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=2)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -797,16 +812,23 @@ class WangBuzsakiModel(hh.WangBuzsakiHH):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=3)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
super().reset_state(batch_size)
if self.input_var:
- self.input.value = variable_(bm.zeros, self.varshape, batch_size)
+ self.input = variable_(bm.zeros, self.varshape, batch_size)
def update(self, x=None):
if self.input_var:
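
The compatibility neuron classes regain an explicit noise argument: without it the model keeps its deterministic ODE integrator, and with it the same drift is re-wrapped as an SDE whose diffusion term is the given noise amplitude. A minimal sketch of that switch, assuming a constant diffusion term (illustrative, not the exact in-tree wiring):

    import brainpy as bp

    def make_integral(derivative, noise=None):
        if noise is None:
            return bp.odeint(derivative)            # deterministic path
        g = lambda *state_and_t: noise              # constant diffusion amplitude
        return bp.sdeint(f=derivative, g=g)         # stochastic path
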
diff --git a/brainpy/_src/dynold/neurons/fractional_models.py b/brainpy/_src/dynold/neurons/fractional_models.py
index 09babeb78..93afd0807 100644
--- a/brainpy/_src/dynold/neurons/fractional_models.py
+++ b/brainpy/_src/dynold/neurons/fractional_models.py
@@ -6,7 +6,7 @@
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy._src.integrators.fde import CaputoL1Schema
from brainpy._src.integrators.fde import GLShortMemory
diff --git a/brainpy/_src/dynold/neurons/reduced_models.py b/brainpy/_src/dynold/neurons/reduced_models.py
index a0c42141d..06784d5de 100644
--- a/brainpy/_src/dynold/neurons/reduced_models.py
+++ b/brainpy/_src/dynold/neurons/reduced_models.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
-from typing import Union, Callable
+from typing import Union, Callable, Optional
from jax.lax import stop_gradient
import brainpy.math as bm
from brainpy._src.context import share
from brainpy._src.dyn.neurons import lif
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import (ZeroInit,
OneInit,
Initializer,
@@ -196,10 +196,17 @@ class LIF(lif.LifRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -320,10 +327,17 @@ class ExpIF(lif.ExpIFRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -421,10 +435,17 @@ class AdExIF(lif.AdExIFRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=2)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -514,10 +535,17 @@ class QuaIF(lif.QuaIFRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=1)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -617,10 +645,17 @@ class AdQuaIF(lif.AdQuaIFRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=2)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -725,10 +760,17 @@ class GIF(lif.GifRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=4)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
@@ -819,10 +861,17 @@ class Izhikevich(lif.IzhikevichRef):
"""
def __init__(
- self, *args, input_var: bool = True, **kwargs,
+ self,
+ *args,
+ input_var: bool = True,
+ noise: Union[float, ArrayType, Initializer, Callable] = None,
+ **kwargs,
):
self.input_var = input_var
super().__init__(*args, **kwargs, init_var=False)
+ self.noise = init_noise(noise, self.varshape, num_vars=2)
+ if self.noise is not None:
+ self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise)
self.reset_state(self.mode)
def reset_state(self, batch_size=None):
diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py
index 8366bbe9c..bc50f8c4c 100644
--- a/brainpy/_src/dynold/synapses/abstract_models.py
+++ b/brainpy/_src/dynold/synapses/abstract_models.py
@@ -8,7 +8,7 @@
from brainpy._src.connect import TwoEndConnector, All2All, One2One
from brainpy._src.dyn import synapses
from brainpy._src.dynold.synouts import MgBlock, CUBA
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import Initializer
from brainpy._src.mixin import AlignPost
from brainpy.types import ArrayType
@@ -293,8 +293,8 @@ def __init__(
if bm.size(self.tau) != 1:
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
- syn = synapses.Expon.desc(pre.size,
- pre.keep_size,
+ syn = synapses.Expon.desc(post.size,
+ post.keep_size,
mode=mode,
tau=tau,
method=method)
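
The descriptor fix above matters for align-post projections: the exponential synaptic state is stored per postsynaptic neuron, so the Expon descriptor has to be sized by the post population rather than the pre one. A hedged usage sketch of the corrected call (population types and sizes are illustrative):

    import brainpy as bp

    pre = bp.dyn.LifRef(8)
    post = bp.dyn.LifRef(4)
    syn_desc = bp.dyn.Expon.desc(post.num, tau=5.)   # one decaying state per post neuron
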
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
index b36b40c9b..bf14cbae0 100644
--- a/brainpy/_src/dynold/synapses/base.py
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -7,7 +7,8 @@
from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
from brainpy._src.dnn import linear
from brainpy._src.dyn import projections
-from brainpy._src.dynsys import Projection, DynamicalSystem, NeuDyn, Sequential
+from brainpy._src.dynsys import Projection, DynamicalSystem
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, ParamDescInit, JointType,
AutoDelaySupp, BindCondData, AlignPost,
@@ -445,7 +446,8 @@ def __init__(
raise UnsupportedError(f'Does not support {comp_method}, only "sparse" or "dense".')
self.proj = proj
self.proj.post.cur_inputs.pop(self.proj.name)
- self.stp = self.pre.after_updates[self.proj._syn_id].syn.stp
+ if hasattr(self.pre.after_updates[self.proj._syn_id].syn, 'stp'):
+ self.stp = self.pre.after_updates[self.proj._syn_id].syn.stp
def update(self, pre_spike=None, stop_spike_gradient: bool = False):
if pre_spike is None:
diff --git a/brainpy/_src/dynold/synapses/biological_models.py b/brainpy/_src/dynold/synapses/biological_models.py
index 861db52e9..bdd04b2b5 100644
--- a/brainpy/_src/dynold/synapses/biological_models.py
+++ b/brainpy/_src/dynold/synapses/biological_models.py
@@ -8,7 +8,7 @@
from brainpy._src.dynold.synapses import _SynSTP, _SynOut, _TwoEndConnAlignPre
from brainpy._src.dynold.synapses.base import _init_stp, _DelayedSyn
from brainpy._src.dynold.synouts import COBA, MgBlock
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy.types import ArrayType
__all__ = [
diff --git a/brainpy/_src/dynold/synapses/compat.py b/brainpy/_src/dynold/synapses/compat.py
index e4b9483bb..108f01ad5 100644
--- a/brainpy/_src/dynold/synapses/compat.py
+++ b/brainpy/_src/dynold/synapses/compat.py
@@ -5,7 +5,7 @@
from brainpy._src.connect import TwoEndConnector
from brainpy._src.dynold.synouts import COBA, CUBA
-from brainpy._src.dynsys import NeuDyn
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import Initializer
from brainpy.types import ArrayType
from .abstract_models import Delta, Exponential, DualExponential
diff --git a/brainpy/_src/dynold/synapses/learning_rules.py b/brainpy/_src/dynold/synapses/learning_rules.py
index 583a2c01b..164803133 100644
--- a/brainpy/_src/dynold/synapses/learning_rules.py
+++ b/brainpy/_src/dynold/synapses/learning_rules.py
@@ -6,7 +6,8 @@
from brainpy._src.dyn import synapses
from brainpy._src.dynold.synouts import CUBA
from brainpy._src.dynold.synapses import _TwoEndConnAlignPre
-from brainpy._src.dynsys import NeuDyn, Sequential
+from brainpy._src.dynsys import Sequential
+from brainpy._src.dyn.base import NeuDyn
from brainpy._src.initialize import Initializer
from brainpy._src.mixin import ParamDesc
from brainpy.types import ArrayType
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 5465d1898..131ad925a 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
-import collections
import gc
import inspect
from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
+import collections
+import jax
import numpy as np
from brainpy import tools, math as bm
from brainpy._src.initialize import parameter, variable_
-from brainpy._src.mixin import AutoDelaySupp, ParamDesc, Container, DelayRegister, global_delay_data
+from brainpy._src.mixin import AutoDelaySupp, Container, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
@@ -22,8 +23,8 @@
# containers
'DynSysGroup', 'Network', 'Sequential',
- # base classes
- 'NeuDyn', 'SynDyn', 'IonChaDyn',
+ # category
+ 'Dynamics', 'Projection',
]
SLICE_VARS = 'slice_vars'
@@ -79,16 +80,16 @@ class DynamicalSystem(bm.BrainPyObject, DelayRegister):
If users want to define the logic of running models across multiple steps,
we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
:py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.
-
+
To be compatible with previous APIs, :py:class:`~.DynamicalSystem` inherits
- from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of
- :py:class:`~.DelayRegister` will be removed in the future, including:
-
+ from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of
+ :py:class:`~.DelayRegister` will be removed in the future, including:
+
- ``.register_delay()``
- ``.get_delay_data()``
- ``.update_local_delays()``
- ``.reset_local_delays()``
-
+
Parameters
----------
name : optional, str
@@ -507,9 +508,6 @@ def __repr__(self):
return f'{self.__class__.__name__}(\n{entries}\n)'
-
-
-
class Projection(DynamicalSystem):
def reset_state(self, *args, **kwargs):
pass
@@ -626,22 +624,7 @@ def __repr__(self):
return f'{self.__class__.__name__}(name={self.name}, mode={self.mode}, size={self.size})'
def __getitem__(self, item):
- return NeuDynView(target=self, index=item)
-
-
-class NeuDyn(Dynamics, AutoDelaySupp):
- """Neuronal Dynamics."""
- pass
-
-
-class SynDyn(Dynamics, AutoDelaySupp, ParamDesc):
- """Synaptic Dynamics."""
- pass
-
-
-class IonChaDyn(Dynamics):
- """Ion Channel Dynamics."""
- pass
+ return DynView(target=self, index=item)
class DynView(Dynamics):
@@ -661,50 +644,42 @@ def __init__(
self,
target: Dynamics,
index: Union[slice, Sequence, ArrayType],
- varshape: Tuple[int, ...] = None,
- name: str = None,
- mode: bm.Mode = None
+ name: Optional[str] = None,
):
- # initialization
- DynamicalSystem.__init__(self, name=name, mode=mode)
-
# check target
- if not isinstance(target, DynamicalSystem):
- raise TypeError(f'Should be instance of DynamicalSystem, but we got {type(target)}.')
+ if not isinstance(target, Dynamics):
+ raise TypeError(f'Should be instance of {Dynamics.__name__}, but we got {type(target)}.')
self.target = target # the target object to slice
# check slicing
if isinstance(index, (int, slice)):
index = (index,)
self.index = index # the slice
+ if len(self.index) > len(target.varshape):
+ raise ValueError(f"Length of the index should be less than "
+ f"that of the target's varshape. But we "
+ f"got {len(self.index)} > {len(target.varshape)}")
# get all variables for slicing
- if not hasattr(self.target, SLICE_VARS):
- if varshape is None:
- if isinstance(target, NeuDyn):
- varshape = target.varshape
- else:
- raise UnsupportedError('Should provide varshape when the target does '
- f'not define its {SLICE_VARS}')
- all_vars = target.vars(level=1, include_self=True, method='relative')
- all_vars = {k: v for k, v in all_vars.items()} # TODO
- # all_vars = {k: v for k, v in all_vars.items() if v.nobatch_shape == varshape}
- else:
+ if hasattr(self.target, SLICE_VARS):
all_vars = {}
for var_str in getattr(self.target, SLICE_VARS):
v = eval(f'target.{var_str}')
all_vars[var_str] = v
+ else:
+ all_vars = target.vars(level=1, include_self=True, method='relative')
+ all_vars = {k: v for k, v in all_vars.items()} # TODO
+ # all_vars = {k: v for k, v in all_vars.items() if v.nobatch_shape == varshape}
# slice variables
self.slice_vars = dict()
for k, v in all_vars.items():
if v.batch_axis is not None:
- index = ((self.index[:v.batch_axis] +
- (slice(None, None, None),) +
- self.index[v.batch_axis:])
- if len(self.index) > v.batch_axis else
- (self.index + tuple([slice(None, None, None)
- for _ in range(v.batch_axis - len(self.index) + 1)])))
+ index = (
+ (self.index[:v.batch_axis] + (slice(None, None, None),) + self.index[v.batch_axis:])
+ if (len(self.index) > v.batch_axis) else
+ (self.index + tuple([slice(None, None, None) for _ in range(v.batch_axis - len(self.index) + 1)]))
+ )
else:
index = self.index
self.slice_vars[k] = bm.VariableView(v, index)
@@ -712,14 +687,32 @@ def __init__(
# sub-nodes
nodes = target.nodes(method='relative', level=1, include_self=False).subset(DynamicalSystem)
for k, node in nodes.items():
- if isinstance(node, NeuDyn):
- node = NeuDynView(node, self.index)
+ if isinstance(node, Dynamics):
+ node = DynView(node, self.index)
else:
- node = DynView(node, self.index, varshape)
+ node = DynView(node, self.index)
setattr(self, k, node)
+ # initialization
+ # get size
+ size = []
+ for i, idx in enumerate(self.index):
+ if isinstance(idx, int):
+ size.append(1)
+ elif isinstance(idx, slice):
+ size.append(_slice_to_num(idx, target.varshape[i]))
+ else:
+ # should be a list/tuple/array of int
+ # do not check again
+ if not isinstance(idx, collections.abc.Iterable):
+ raise TypeError('Should be an iterable object of int.')
+ size.append(len(idx))
+ size += list(target.varshape[len(self.index):])
+
+ super().__init__(size, keep_size=target.keep_size, name=name, mode=target.mode)
+
def __repr__(self):
- return f'{self.__class__.__name__}(target={self.target}, index={self.index})'
+ return f'{self.name}(target={self.target}, index={self.index})'
def __getattribute__(self, item):
try:
@@ -733,7 +726,7 @@ def __getattribute__(self, item):
def __setattr__(self, key, value):
if hasattr(self, 'slice_vars'):
- slice_vars = super(DynView, self).__getattribute__('slice_vars')
+ slice_vars = super().__getattribute__('slice_vars')
if key in slice_vars:
v = slice_vars[key]
v.value = value
@@ -741,7 +734,8 @@ def __setattr__(self, key, value):
super(DynView, self).__setattr__(key, value)
def update(self, *args, **kwargs):
- raise NoImplementationError(f'DSView {self} cannot be updated. Please update its parent {self.target}')
+ raise NoImplementationError(f'{DynView.__name__} {self} cannot be updated. '
+ f'Please update its parent {self.target}')
def reset_state(self, batch_size=None):
pass
@@ -773,41 +767,3 @@ def _slice_to_num(slice_: slice, length: int):
start += step
num += 1
return num
-
-
-class NeuDynView(DynView, NeuDyn):
- """A view for a neuron group instance."""
-
- def __init__(
- self,
- target: NeuDyn,
- index: Union[slice, Sequence, ArrayType],
- name: str = None,
- mode: bm.Mode = None
- ):
- DynView.__init__(self, target, index)
-
- # check slicing
- var_shapes = target.varshape
- if len(self.index) > len(var_shapes):
- raise ValueError(f"Length of the index should be less than "
- f"that of the target's varshape. But we "
- f"got {len(self.index)} > {len(var_shapes)}")
-
- # get size
- size = []
- for i, idx in enumerate(self.index):
- if isinstance(idx, int):
- size.append(1)
- elif isinstance(idx, slice):
- size.append(_slice_to_num(idx, var_shapes[i]))
- else:
- # should be a list/tuple/array of int
- # do not check again
- if not isinstance(idx, collections.Iterable):
- raise TypeError('Should be an iterable object of int.')
- size.append(len(idx))
- size += list(var_shapes[len(self.index):])
-
- # initialization
- NeuDyn.__init__(self, tuple(size), name=name, mode=mode)
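
With NeuDynView removed, DynView itself now derives the view's size from the slicing index: an integer entry contributes a dimension of 1, a slice contributes its element count, an index sequence contributes its length, and any remaining target dimensions are appended unchanged; keep_size and mode are inherited from the target. A small sketch of that size rule (illustrative stand-in for the in-tree logic):

    def view_size(index, varshape):
        size = []
        for i, idx in enumerate(index):
            if isinstance(idx, int):
                size.append(1)
            elif isinstance(idx, slice):
                size.append(len(range(*idx.indices(varshape[i]))))
            else:
                size.append(len(idx))        # sequence/array of indices
        return tuple(size) + tuple(varshape[len(index):])

    # e.g. view_size((slice(0, 5), 2), (10, 20)) == (5, 1)
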
diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py
index 9d1b1adcf..b2d142c0e 100644
--- a/brainpy/_src/integrators/ode/exponential.py
+++ b/brainpy/_src/integrators/ode/exponential.py
@@ -138,7 +138,7 @@ class ExponentialEuler(ODEIntegrator):
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
- >>> class HH(bp.NeuDyn):
+ >>> class HH(bp.dyn.NeuDyn):
>>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
>>> gL=0.1, V_th=20., phi=5.0, name=None):
>>> super(HH, self).__init__(size=size, name=name)
@@ -211,7 +211,7 @@ class ExponentialEuler(ODEIntegrator):
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
- >>> class HH(bp.NeuDyn):
+ >>> class HH(bp.dyn.NeuDyn):
>>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
>>> gL=0.1, V_th=20., phi=5.0, name=None):
>>> super(HH, self).__init__(size=size, name=name)
diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
index 46654c4a0..2b8dd6781 100644
--- a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
@@ -46,7 +46,7 @@ def dev(x, t):
class TestExpEulerAuto(unittest.TestCase):
def test_hh_model(self):
- class HH(bp.NeuDyn):
+ class HH(bp.dyn.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
gL=0.1, V_th=20., phi=5.0, name=None, method='exponential_euler'):
super(HH, self).__init__(size=size, name=name)
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py
index 259952a6b..5468a4fcb 100644
--- a/brainpy/_src/math/event/tests/test_event_csrmv.py
+++ b/brainpy/_src/math/event/tests/test_event_csrmv.py
@@ -9,12 +9,13 @@
import brainpy as bp
import brainpy.math as bm
+import platform
-import brainpylib as bl
import pytest
-if bl.__version__ < '0.1.9':
- pytest.skip('Need brainpylib>=0.1.9', allow_module_level=True)
+is_manual_test = False
+if platform.system() == 'Windows' and not is_manual_test:
+ pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
def sum_op(op):
diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py
index 1af2a3aeb..c8d661233 100644
--- a/brainpy/_src/math/jitconn/_event_matvec.py
+++ b/brainpy/_src/math/jitconn/_event_matvec.py
@@ -10,7 +10,6 @@
from jax.interpreters import xla, ad
from jax.lib import xla_client
-from brainpy._src.math.ndarray import _get_dtype
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p,
mv_prob_uniform_p,
@@ -18,8 +17,9 @@
mv_prob_homo,
mv_prob_uniform,
mv_prob_normal)
+from brainpy._src.math.ndarray import _get_dtype
from brainpy._src.math.op_registers import register_general_batching
-from brainpy.errors import GPUOperatorNotFound, MathError
+from brainpy.errors import GPUOperatorNotFound
try:
from brainpylib import gpu_ops
@@ -168,15 +168,7 @@ def _event_matvec_prob_homo_cpu_translation(
c, events, weight, clen, seed, *, shape, transpose, outdim_parallel
):
n_row, n_col = (shape[1], shape[0]) if transpose else shape
- event_shape = c.get_shape(events)
- if event_shape.element_type() == jnp.bool_:
- event_type = b'_bool'
- out_dtype = dtypes.canonicalize_dtype(float)
- type_name = b'_float' if out_dtype == jnp.float32 else b'_double'
- else:
- out_dtype = event_shape.element_type()
- event_type = b'_float' if out_dtype == jnp.float32 else b'_double'
- type_name = event_type
+ out_dtype, event_type, type_name = _get_types(c.get_shape(events))
if outdim_parallel:
fn = b'cpu_event_matvec_prob_homo' + type_name + event_type
@@ -212,15 +204,7 @@ def _event_matvec_prob_homo_gpu_translation(
if gpu_ops is None:
raise GPUOperatorNotFound(event_mv_prob_homo_p.name)
- event_shape = c.get_shape(events)
- if event_shape.element_type() == jnp.bool_:
- event_type = b'_bool'
- out_dtype = dtypes.canonicalize_dtype(float)
- type_name = b'_float' if out_dtype == jnp.float32 else b'_double'
- else:
- out_dtype = event_shape.element_type()
- event_type = b'_float' if out_dtype == jnp.float32 else b'_double'
- type_name = event_type
+ out_dtype, event_type, type_name = _get_types(c.get_shape(events))
opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0],
shape[0] if transpose else shape[1], )
@@ -367,15 +351,7 @@ def _event_matvec_prob_uniform_cpu_translation(
):
n_row, n_col = (shape[1], shape[0]) if transpose else shape
- event_shape = c.get_shape(events)
- if event_shape.element_type() == jnp.bool_:
- event_type = b'_bool'
- out_dtype = dtypes.canonicalize_dtype(float)
- type_name = b'_float' if (out_dtype == jnp.float32) else b'_double'
- else:
- out_dtype = event_shape.element_type()
- event_type = b'_float' if (out_dtype == jnp.float32) else b'_double'
- type_name = event_type
+ out_dtype, event_type, type_name = _get_types(c.get_shape(events))
if outdim_parallel:
fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type
@@ -412,15 +388,7 @@ def _event_matvec_prob_uniform_gpu_translation(
if gpu_ops is None:
raise GPUOperatorNotFound(event_mv_prob_uniform_p.name)
- event_shape = c.get_shape(events)
- if event_shape.element_type() == jnp.bool_:
- event_type = b'_bool'
- out_dtype = dtypes.canonicalize_dtype(float)
- type_name = b'_float' if out_dtype == jnp.float32 else b'_double'
- else:
- out_dtype = event_shape.element_type()
- event_type = b'_float' if out_dtype == jnp.float32 else b'_double'
- type_name = event_type
+ out_dtype, event_type, type_name = _get_types(c.get_shape(events))
opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0],
shape[0] if transpose else shape[1])
@@ -513,7 +481,6 @@ def _event_matvec_prob_normal_abstract(
_w_sigma_dtype = _get_dtype(w_sigma)
assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.'
assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.'
- assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.'
assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]
assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]
@@ -547,20 +514,36 @@ def _event_matvec_prob_normal_abstract(
return [out]
+def _get_types(event_shape):
+ event_type = event_shape.element_type()
+ if event_type == jnp.bool_:
+ event_type = b'_bool'
+ out_dtype = dtypes.canonicalize_dtype(float)
+ elif event_type == jnp.float32:
+ event_type = b'_float'
+ out_dtype = event_shape.element_type()
+ elif event_type == jnp.float64:
+ event_type = b'_double'
+ out_dtype = event_shape.element_type()
+ else:
+ raise TypeError
+
+ if out_dtype == jnp.float32:
+ type_name = b'_float'
+ elif out_dtype == jnp.float64:
+ type_name = b'_double'
+ else:
+ raise TypeError
+
+ return out_dtype, event_type, type_name
+
+
def _event_matvec_prob_normal_cpu_translation(
c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel
):
n_row, n_col = (shape[1], shape[0]) if transpose else shape
- event_shape = c.get_shape(events)
- if event_shape.element_type() == jnp.bool_:
- event_type = b'_bool'
- out_dtype = dtypes.canonicalize_dtype(float)
- type_name = b'_float' if out_dtype == jnp.float32 else b'_double'
- else:
- out_dtype = event_shape.element_type()
- event_type = b'_float' if out_dtype == jnp.float32 else b'_double'
- type_name = event_type
+ out_dtype, event_type, type_name = _get_types(c.get_shape(events))
if outdim_parallel:
fn = b'cpu_event_matvec_prob_normal' + type_name + event_type
@@ -597,15 +580,8 @@ def _event_matvec_prob_normal_gpu_translation(
if gpu_ops is None:
raise GPUOperatorNotFound(event_mv_prob_normal_p.name)
- event_shape = c.get_shape(events)
- if event_shape.element_type() == jnp.bool_:
- event_type = b'_bool'
- out_dtype = dtypes.canonicalize_dtype(float)
- type_name = b'_float' if out_dtype == jnp.float32 else b'_double'
- else:
- out_dtype = event_shape.element_type()
- event_type = b'_float' if out_dtype == jnp.float32 else b'_double'
- type_name = event_type
+ out_dtype, event_type, type_name = _get_types(c.get_shape(events))
+
opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0],
shape[0] if transpose else shape[1])
if outdim_parallel:
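
The repeated dtype-dispatch blocks in the CPU and GPU translation rules are folded into _get_types: the event buffer's element type selects the kernel-name suffixes and the output dtype, with boolean events falling back to the canonical float type. A plain-NumPy sketch of that dispatch table (the real helper operates on XLA shape objects):

    import numpy as np

    def get_types_sketch(event_dtype, default_float=np.dtype('float32')):
        if event_dtype == np.bool_:
            out_dtype, event_type = default_float, b'_bool'
        elif event_dtype == np.float32:
            out_dtype, event_type = np.dtype('float32'), b'_float'
        elif event_dtype == np.float64:
            out_dtype, event_type = np.dtype('float64'), b'_double'
        else:
            raise TypeError(f'Unsupported event dtype: {event_dtype}')
        type_name = b'_float' if out_dtype == np.float32 else b'_double'
        return out_dtype, event_type, type_name
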
diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py
index 7ebeef6c0..f442cbada 100644
--- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py
+++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py
@@ -4,13 +4,14 @@
import jax.numpy as jnp
from absl.testing import parameterized
+import platform
import brainpy.math as bm
-import brainpylib as bl
import pytest
-if bl.__version__ < '0.1.9':
- pytest.skip('Need brainpylib>=0.1.9', allow_module_level=True)
+is_manual_test = False
+if platform.system() == 'Windows' and not is_manual_test:
+  pytest.skip('Under Windows, brainpy.math package may need manual tests.', allow_module_level=True)
shapes = [(100, 200),
@@ -26,28 +27,15 @@ def __init__(self, *args, platform='cpu', **kwargs):
bm.set_platform(platform)
print()
- @parameterized.named_parameters(
- dict(testcase_name=f'_test_homo: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'homo_data = {homo_data}, '
- f'bool_event = {bool_event}, '
- f'x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- homo_data=homo_data,
- bool_event=bool_event, seed=1234, x64=x64)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, 0.5]
- for homo_data in [-1., ]
- for bool_event in [True, False]
+ @parameterized.product(
+ transpose=[True, False],
+ x64=[True, False],
+ outdim_parallel=[True, False],
+ shape=shapes,
+ prob=[0.01, 0.1, 0.5],
+    homo_data=[-1.],
+    bool_event=[True, False],
+    seed=[1234],
)
def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False):
print(f'_test_homo: '
@@ -73,6 +61,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
+ r1 = jax.block_until_ready(r1)
r2 = bm.jitconn.event_mv_prob_homo(events,
homo_data,
@@ -81,6 +70,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(jnp.allclose(r1, r2))
r3 = bm.jitconn.event_mv_prob_homo(events,
@@ -90,6 +80,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve
seed=seed,
outdim_parallel=outdim_parallel,
transpose=not transpose)
+ r3 = jax.block_until_ready(r3)
self.assertTrue(jnp.allclose(r1, r3))
# indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post')
@@ -103,27 +94,16 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve
bm.disable_x64()
bm.clear_buffer_memory()
- @parameterized.named_parameters(
- dict(testcase_name=f'_test_homo_vmap: '
- f'shape = {shape}, '
- f'transpose = {transpose}, '
- f'outdim_parallel = {outdim_parallel}, '
- f'prob={prob}, '
- f'bool_event = {bool_event}, x64={x64}',
- shape=shape,
- transpose=transpose,
- outdim_parallel=outdim_parallel,
- prob=prob,
- x64=x64,
- bool_event=bool_event,
- seed=1234)
- for transpose in [True, False]
- for x64 in [True, False]
- for outdim_parallel in [True, False]
- for shape in shapes
- for prob in [0.01, 0.1, 0.5]
- for bool_event in [True, False]
+  @parameterized.product(
+    transpose=[True, False],
+    x64=[True, False],
+    outdim_parallel=[True, False],
+    shape=shapes,
+    prob=[0.01, 0.1, 0.5],
+    bool_event=[True, False],
+    seed=[1234],
)
def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False):
print(f'_test_homo_vmap: '
@@ -149,7 +129,9 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=Tru
)
)
r1 = f1(events, weights)
+ r1 = jax.block_until_ready(r1)
r2 = f1(events, weights)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(jnp.allclose(r1, r2))
if x64:
bm.disable_x64()
@@ -192,10 +174,13 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64
argnums=0
)
r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
r3 = f1(events, 3.)
+ r3 = jax.block_until_ready(r3)
self.assertTrue(jnp.allclose(r1 * 3., r3))
self.assertTrue(jnp.allclose(r1 * 2., r2))
@@ -257,6 +242,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
+ r1 = jax.block_until_ready(r1)
r2 = bm.jitconn.event_mv_prob_uniform(events,
w_low=w_low,
@@ -266,6 +252,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(jnp.allclose(r1, r2))
r3 = bm.jitconn.event_mv_prob_uniform(events,
@@ -276,6 +263,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=not transpose)
+ r3 = jax.block_until_ready(r3)
self.assertTrue(jnp.allclose(r1, r3))
if x64:
bm.disable_x64()
@@ -328,7 +316,9 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob,
)
r1 = f1(events)
+ r1 = jax.block_until_ready(r1)
r2 = f1(events)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(jnp.allclose(r1, r2))
if x64:
bm.disable_x64()
@@ -377,7 +367,9 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None,
)
r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(bm.allclose(r1 * 2., r2))
# print(r1)
if x64:
@@ -432,6 +424,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
+ r1 = jax.block_until_ready(r1)
r2 = bm.jitconn.event_mv_prob_normal(events,
w_mu=w_mu,
@@ -441,6 +434,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=transpose)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(jnp.allclose(r1, r2))
r3 = bm.jitconn.event_mv_prob_normal(events,
@@ -451,6 +445,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma,
seed=seed,
outdim_parallel=outdim_parallel,
transpose=not transpose)
+ r3 = jax.block_until_ready(r3)
self.assertTrue(jnp.allclose(r1, r3))
if x64:
@@ -503,7 +498,9 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob,
outdim_parallel=outdim_parallel,
transpose=transpose))
r1 = f1(events)
+ r1 = jax.block_until_ready(r1)
r2 = f1(events)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(jnp.allclose(r1, r2))
if x64:
bm.disable_x64()
@@ -540,19 +537,23 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x
events = bm.as_jax(events)
events = events.astype(float)
- f1 = jax.grad(
- lambda e, w_sigma: bm.jitconn.event_mv_prob_normal(
- e,
- w_mu=0.,
- w_sigma=w_sigma,
- conn_prob=prob,
- shape=shape,
- seed=seed,
- outdim_parallel=outdim_parallel,
- transpose=transpose).sum()
+ f1 = jax.jit(
+ jax.grad(
+ lambda e, w_sigma: bm.jitconn.event_mv_prob_normal(
+ e,
+ w_mu=0.,
+ w_sigma=w_sigma,
+ conn_prob=prob,
+ shape=shape,
+ seed=seed,
+ outdim_parallel=outdim_parallel,
+ transpose=transpose).sum()
+ )
)
r1 = f1(events, 1.)
+ r1 = jax.block_until_ready(r1)
r2 = f1(events, 2.)
+ r2 = jax.block_until_ready(r2)
self.assertTrue(bm.allclose(r1 * 2, r2))
if x64:
bm.disable_x64()
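
Two test-side changes recur in this file and in test_matvec.py and test_csrmv.py below: the hand-written named_parameters comprehensions are replaced by parameterized.product, and every result is passed through jax.block_until_ready so that JAX's asynchronous dispatch completes inside the test body. A small self-contained sketch of the pattern (made-up parameter values; assumes absl-py and jax are installed, and is not part of the patch):

    import jax
    import jax.numpy as jnp
    from absl.testing import parameterized

    class SketchTest(parameterized.TestCase):
      # product() expands the Cartesian product of the keyword lists into one
      # test case per combination, replacing the manual dict comprehension.
      @parameterized.product(transpose=[True, False], prob=[0.1, 0.5])
      def test_sketch(self, transpose, prob):
        x = jnp.full((4, 4), prob)
        y = (x.T if transpose else x).sum()
        # block_until_ready forces async dispatch to finish here, so backend
        # errors surface in this test rather than at teardown.
        y = jax.block_until_ready(y)
        self.assertTrue(bool(jnp.isfinite(y)))
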
diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py
index e202d39d6..91c48fc66 100644
--- a/brainpy/_src/math/jitconn/tests/test_matvec.py
+++ b/brainpy/_src/math/jitconn/tests/test_matvec.py
@@ -5,12 +5,12 @@
from absl.testing import parameterized
import brainpy.math as bm
-
-import brainpylib as bl
+import platform
import pytest
-if bl.__version__ < '0.1.9':
- pytest.skip('Need brainpylib>=0.1.9', allow_module_level=True)
+is_manual_test = False
+if platform.system() == 'Windows' and not is_manual_test:
+ pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
shapes = [(100, 200),
(10, 1000),
diff --git a/brainpy/_src/math/object_transform/tests/test_circular_reference.py b/brainpy/_src/math/object_transform/tests/test_circular_reference.py
index 2dc076ff4..61606d36e 100644
--- a/brainpy/_src/math/object_transform/tests/test_circular_reference.py
+++ b/brainpy/_src/math/object_transform/tests/test_circular_reference.py
@@ -5,7 +5,7 @@
import brainpy as bp
-class HH(bp.NeuDyn):
+class HH(bp.dyn.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0,
gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs):
super(HH, self).__init__(size=size, **kwargs)
diff --git a/brainpy/_src/math/object_transform/tests/test_collector.py b/brainpy/_src/math/object_transform/tests/test_collector.py
index f5b7fb0d0..9c3d5dde6 100644
--- a/brainpy/_src/math/object_transform/tests/test_collector.py
+++ b/brainpy/_src/math/object_transform/tests/test_collector.py
@@ -40,7 +40,7 @@ def update(self, tdi):
self.post.inputs -= jnp.sum(self.s, axis=0) * (self.post.V - self.E)
-class HH_without_Variable(bp.NeuDyn):
+class HH_without_Variable(bp.dyn.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0,
gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs):
super(HH_without_Variable, self).__init__(size=size, **kwargs)
@@ -117,7 +117,7 @@ def test_neu_vars_1():
assert len(vars) == 0
-class HH_with_Variable(bp.NeuDyn):
+class HH_with_Variable(bp.dyn.NeuDyn):
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0,
gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs):
super(HH_with_Variable, self).__init__(size=size, **kwargs)
diff --git a/brainpy/_src/math/object_transform/tests/test_namechecking.py b/brainpy/_src/math/object_transform/tests/test_namechecking.py
index c008cd4a9..70b60cbb3 100644
--- a/brainpy/_src/math/object_transform/tests/test_namechecking.py
+++ b/brainpy/_src/math/object_transform/tests/test_namechecking.py
@@ -4,7 +4,7 @@
import brainpy as bp
-class LIF(bp.NeuDyn):
+class LIF(bp.dyn.NeuDyn):
pass
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py
index 8b193ba78..3a550ac64 100644
--- a/brainpy/_src/math/sparse/tests/test_csrmv.py
+++ b/brainpy/_src/math/sparse/tests/test_csrmv.py
@@ -2,17 +2,17 @@
from functools import partial
-import brainpylib as bl
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
-
+import platform
import brainpy as bp
import brainpy.math as bm
-if bl.__version__ < '0.1.9':
- pytest.skip('Need brainpylib>=0.1.9', allow_module_level=True)
+is_manual_test = False
+if platform.system() == 'Windows' and not is_manual_test:
+ pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True)
cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse')
scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar')
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 0718b06e4..5fed869ff 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -8,7 +8,8 @@
from brainpy import math as bm, tools
from brainpy._src.initialize import parameter
-from brainpy._src.typing_copy import _SpecialForm, _UnionGenericAlias, _type_check, _remove_dups_flatten
+from brainpy._src.python_typing_copied import (_SpecialForm, _UnionGenericAlias,
+ _type_check, _remove_dups_flatten)
from brainpy.types import ArrayType
DynamicalSystem = None
@@ -405,6 +406,9 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
target = global_delay_data[name][1]
delay.reset(target.value)
+ def get_delay_var(self, name):
+ return global_delay_data[name]
+
class BindCondData(MixIn):
"""Bind temporary conductance data.
diff --git a/brainpy/_src/typing_copy.py b/brainpy/_src/python_typing_copied.py
similarity index 100%
rename from brainpy/_src/typing_copy.py
rename to brainpy/_src/python_typing_copied.py
diff --git a/brainpy/_src/tests/test_access_methods.py b/brainpy/_src/tests/test_access_methods.py
index 1e361ffbd..6d2109cbd 100644
--- a/brainpy/_src/tests/test_access_methods.py
+++ b/brainpy/_src/tests/test_access_methods.py
@@ -6,7 +6,7 @@
bp.ode.set_default_odeint('rk4')
-class GABAa(bp.TwoEndConn):
+class GABAa(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa, self).__init__(pre=pre, post=post, conn=conn, **kwargs)
diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py
index 169d12824..0cc2bb90c 100644
--- a/brainpy/_src/tests/test_dyn_runner.py
+++ b/brainpy/_src/tests/test_dyn_runner.py
@@ -73,8 +73,7 @@ def __init__(self, scale=1.0, method='exp_auto'):
# without JIT
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike},
- inputs=[(net.E.input, 20.), (net.I.input, 20.)],
- jit=False).run(0.2)
+ inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2)
diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py
index fa9a43177..1544a1f33 100644
--- a/brainpy/_src/tests/test_mixin.py
+++ b/brainpy/_src/tests/test_mixin.py
@@ -6,13 +6,13 @@
class TestParamDesc(unittest.TestCase):
def test1(self):
a = bp.dyn.Expon(1)
- self.assertTrue(not isinstance(a, bp.mixin.ParamDesc[bp.dyn.Expon]))
- self.assertTrue(not isinstance(a, bp.mixin.ParamDesc[bp.DynamicalSystem]))
+ self.assertTrue(not isinstance(a, bp.mixin.ParamDescInit[bp.dyn.Expon]))
+ self.assertTrue(not isinstance(a, bp.mixin.ParamDescInit[bp.DynamicalSystem]))
def test2(self):
a = bp.dyn.Expon.desc(1)
- self.assertTrue(isinstance(a, bp.mixin.ParamDesc[bp.dyn.Expon]))
- self.assertTrue(isinstance(a, bp.mixin.ParamDesc[bp.DynamicalSystem]))
+ self.assertTrue(isinstance(a, bp.mixin.ParamDescInit[bp.dyn.Expon]))
+ self.assertTrue(isinstance(a, bp.mixin.ParamDescInit[bp.DynamicalSystem]))
class TestJointType(unittest.TestCase):
@@ -25,6 +25,6 @@ def test1(self):
def test2(self):
T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc]
- self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDesc[T]))
- self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDesc[T]))
+ self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescInit[T]))
+ self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDescInit[T]))
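
The assertions above reflect that the subscriptable runtime check now lives on ParamDescInit rather than ParamDesc. A hedged usage sketch mirroring the updated test (assumes the patched brainpy package is importable):

    import brainpy as bp

    desc = bp.dyn.Expon.desc(1)   # a delayed-initialization descriptor, as in the test
    assert isinstance(desc, bp.mixin.ParamDescInit[bp.dyn.Expon])
    assert not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescInit[bp.dyn.Expon])
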
diff --git a/brainpy/_src/tests/test_slice_view.py b/brainpy/_src/tests/test_slice_view.py
index a952528fb..1383c1a6c 100644
--- a/brainpy/_src/tests/test_slice_view.py
+++ b/brainpy/_src/tests/test_slice_view.py
@@ -45,7 +45,3 @@ def test_lif_train_mode(self):
print('After modification 2: ')
print(lif.V)
-
-
-
-
diff --git a/brainpy/dyn/__init__.py b/brainpy/dyn/__init__.py
index 6471e011d..b3272e45a 100644
--- a/brainpy/dyn/__init__.py
+++ b/brainpy/dyn/__init__.py
@@ -1,4 +1,5 @@
+from .base import *
from .ions import *
from .channels import *
from .neurons import *
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
new file mode 100644
index 000000000..5d94717c4
--- /dev/null
+++ b/brainpy/dyn/base.py
@@ -0,0 +1,7 @@
+
+from brainpy._src.dyn.base import (
+ Dynamics,
+ NeuDyn,
+ SynDyn,
+ IonChaDyn,
+)
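
The new brainpy/dyn/base.py only re-exports the dynamics base classes from the private _src package; this is what lets the test files above switch from bp.NeuDyn to bp.dyn.NeuDyn. A hedged usage sketch (assumes the patched package is installed):

    import brainpy as bp

    class MyNeuron(bp.dyn.NeuDyn):   # public path provided by the re-export above
      def __init__(self, size, **kwargs):
        super().__init__(size=size, **kwargs)
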
diff --git a/brainpy/dyn/channels.py b/brainpy/dyn/channels.py
index df5bdd927..11809476a 100644
--- a/brainpy/dyn/channels.py
+++ b/brainpy/dyn/channels.py
@@ -8,6 +8,7 @@
ICaT_HM1992,
ICaT_HP1992,
ICaHT_HM1992,
+ ICaHT_Re1993,
ICaL_IS2008,
)
diff --git a/brainpy/dyn/rates.py b/brainpy/dyn/rates.py
index e69de29bb..3b18ea24e 100644
--- a/brainpy/dyn/rates.py
+++ b/brainpy/dyn/rates.py
@@ -0,0 +1,8 @@
+from brainpy._src.dyn.rates import (
+ FHN,
+ FeedbackFHN,
+ QIF,
+ StuartLandauOscillator,
+ WilsonCowanModel,
+ ThresholdLinearModel,
+)
diff --git a/brainpy/dyn/synapses.py b/brainpy/dyn/synapses.py
index e59a33826..77ab86632 100644
--- a/brainpy/dyn/synapses.py
+++ b/brainpy/dyn/synapses.py
@@ -3,5 +3,19 @@
Delta,
Expon,
DualExpon,
+ NMDA,
+ STD,
+ STP,
)
+from brainpy._src.dyn.synapses.bio_models import (
+ AMPA,
+ GABAa,
+ BioNMDA,
+)
+from brainpy._src.dyn.synapses.delay_couplings import (
+ DiffusiveCoupling,
+ AdditiveCoupling,
+)
+
+
diff --git a/brainpy/mixin.py b/brainpy/mixin.py
index 61bd0dca4..854009283 100644
--- a/brainpy/mixin.py
+++ b/brainpy/mixin.py
@@ -4,6 +4,7 @@
AlignPost as AlignPost,
AutoDelaySupp as AutoDelaySupp,
ParamDesc as ParamDesc,
+ ParamDescInit as ParamDescInit,
NoSH as NoSH,
Container as Container,
TreeNode as TreeNode,
diff --git a/brainpy/rates.py b/brainpy/rates.py
index faaaf799c..10f7e4873 100644
--- a/brainpy/rates.py
+++ b/brainpy/rates.py
@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-
+from .dyn.rates import *
+
diff --git a/brainpy/synapses.py b/brainpy/synapses.py
index 1d1b6364f..d07fb1954 100644
--- a/brainpy/synapses.py
+++ b/brainpy/synapses.py
@@ -30,4 +30,9 @@
from brainpy._src.dynold.synapses.learning_rules import (
STP as STP,
)
+from brainpy._src.dyn.synapses.delay_couplings import (
+ DiffusiveCoupling,
+ AdditiveCoupling,
+)
+
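
With these re-exports, the delay-coupling models become reachable from both the new brainpy.dyn.synapses namespace and the legacy brainpy.synapses module. A hedged import sketch (assumes the patched package):

    from brainpy.dyn.synapses import DiffusiveCoupling, AdditiveCoupling
    from brainpy.synapses import DiffusiveCoupling as LegacyDiffusiveCoupling
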
From 1a64014d81f6dbbdcf9aea645ad24af2e726550d Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 15:35:52 +0800
Subject: [PATCH 017/326] fix typing
---
brainpy/_src/dnn/__init__.py | 1 -
brainpy/_src/dnn/activations.py | 2 +-
brainpy/_src/dnn/conv.py | 2 +-
brainpy/_src/dnn/dropout.py | 2 +-
brainpy/_src/dnn/function.py | 2 +-
brainpy/_src/dnn/interoperation_flax.py | 2 +-
brainpy/_src/dnn/linear.py | 2 +-
brainpy/_src/dnn/normalization.py | 2 +-
brainpy/_src/dnn/nvar.py | 2 +-
brainpy/_src/dnn/pooling.py | 2 +-
brainpy/_src/dnn/reservoir.py | 2 +-
brainpy/_src/dnn/rnncells.py | 2 +-
brainpy/_src/{dnn/base.py => layer.py} | 0
brainpy/_src/losses/base.py | 2 +-
brainpy/_src/mixin.py | 127 +-
brainpy/_src/python_typing_copied.py | 2273 -----------------------
brainpy/dnn/others.py | 4 -
brainpy/neurons.py | 9 +
tests/simulation/test_neu_HH.py | 2 +-
19 files changed, 115 insertions(+), 2325 deletions(-)
rename brainpy/_src/{dnn/base.py => layer.py} (100%)
delete mode 100644 brainpy/_src/python_typing_copied.py
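
This patch does two things: it moves the shared Layer base class from brainpy/_src/dnn/base.py to brainpy/_src/layer.py, rewriting the relative imports in the dnn modules accordingly, and it deletes the 2273-line copy of CPython's typing module in favour of importing the needed private helpers from typing itself, guarded by the interpreter version. A minimal sketch of both changes (matching the hunks below, shown here only for orientation):

    # (a) dnn modules import the shared base from its new home:
    from brainpy._src.layer import Layer

    # (b) typing helpers come from the stdlib, guarded by the Python version:
    import sys
    if sys.version_info >= (3, 9):
      from typing import _UnionGenericAlias
    else:
      from typing import _GenericAlias, _tp_cache
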
diff --git a/brainpy/_src/dnn/__init__.py b/brainpy/_src/dnn/__init__.py
index f4b5f62c0..6fa1eb184 100644
--- a/brainpy/_src/dnn/__init__.py
+++ b/brainpy/_src/dnn/__init__.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-from .base import *
from .activations import *
from .dropout import *
from .nvar import *
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index e7461b016..a1bef95e0 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -2,7 +2,7 @@
from brainpy import math as bm
from brainpy.types import ArrayType
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index 4d3fe8366..f5e4a1e60 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -7,7 +7,7 @@
from brainpy import math as bm, tools
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Conv1d', 'Conv2d', 'Conv3d',
diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py
index dd60cc1df..184a46aa5 100644
--- a/brainpy/_src/dnn/dropout.py
+++ b/brainpy/_src/dnn/dropout.py
@@ -4,7 +4,7 @@
from brainpy._src.context import share
from brainpy import math as bm, check
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Dropout'
diff --git a/brainpy/_src/dnn/function.py b/brainpy/_src/dnn/function.py
index b4a39f6f2..7d12246b4 100644
--- a/brainpy/_src/dnn/function.py
+++ b/brainpy/_src/dnn/function.py
@@ -5,7 +5,7 @@
import brainpy.math as bm
from brainpy import check
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Activation',
diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py
index b0c9c01ac..ce98964fc 100644
--- a/brainpy/_src/dnn/interoperation_flax.py
+++ b/brainpy/_src/dnn/interoperation_flax.py
@@ -7,7 +7,7 @@
from brainpy import math as bm
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
-from .base import Layer
+from brainpy._src.layer import Layer
try:
import flax # noqa
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index a5faccc10..b4f638fca 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -14,7 +14,7 @@
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Dense', 'Linear',
diff --git a/brainpy/_src/dnn/normalization.py b/brainpy/_src/dnn/normalization.py
index 38e59d061..e99e162c3 100644
--- a/brainpy/_src/dnn/normalization.py
+++ b/brainpy/_src/dnn/normalization.py
@@ -8,7 +8,7 @@
from brainpy import math as bm, check
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.types import ArrayType
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'BatchNorm1d',
diff --git a/brainpy/_src/dnn/nvar.py b/brainpy/_src/dnn/nvar.py
index b2eab7eca..da1f6ed48 100644
--- a/brainpy/_src/dnn/nvar.py
+++ b/brainpy/_src/dnn/nvar.py
@@ -8,7 +8,7 @@
import brainpy.math as bm
from brainpy import check
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'NVAR'
diff --git a/brainpy/_src/dnn/pooling.py b/brainpy/_src/dnn/pooling.py
index 3ff24d8a4..3bb38ff3b 100644
--- a/brainpy/_src/dnn/pooling.py
+++ b/brainpy/_src/dnn/pooling.py
@@ -7,7 +7,7 @@
import numpy as np
from brainpy import math as bm, check
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'MaxPool',
diff --git a/brainpy/_src/dnn/reservoir.py b/brainpy/_src/dnn/reservoir.py
index 6cab48a29..c5ea3cb5a 100644
--- a/brainpy/_src/dnn/reservoir.py
+++ b/brainpy/_src/dnn/reservoir.py
@@ -9,7 +9,7 @@
from brainpy import check
from brainpy.tools import to_size
from brainpy.types import ArrayType
-from .base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Reservoir',
diff --git a/brainpy/_src/dnn/rnncells.py b/brainpy/_src/dnn/rnncells.py
index d3feb9276..2df1b4a76 100644
--- a/brainpy/_src/dnn/rnncells.py
+++ b/brainpy/_src/dnn/rnncells.py
@@ -7,7 +7,7 @@
import brainpy.math as bm
from brainpy.math import activations
-from .base import Layer
+from brainpy._src.layer import Layer
from brainpy.check import (is_integer,
is_initializer)
from brainpy.initialize import (XavierNormal,
diff --git a/brainpy/_src/dnn/base.py b/brainpy/_src/layer.py
similarity index 100%
rename from brainpy/_src/dnn/base.py
rename to brainpy/_src/layer.py
diff --git a/brainpy/_src/losses/base.py b/brainpy/_src/losses/base.py
index a01e2aee8..e8f6434fa 100644
--- a/brainpy/_src/losses/base.py
+++ b/brainpy/_src/losses/base.py
@@ -1,6 +1,6 @@
from typing import Optional
-from brainpy._src.dnn.base import Layer
+from brainpy._src.layer import Layer
__all__ = [
'Loss',
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 5fed869ff..143c8884f 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -1,6 +1,8 @@
import numbers
+import sys
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, TypeVar
+from typing import (_SpecialForm, _type_check, _remove_dups_flatten)
import jax
import jax.numpy as jnp
@@ -8,10 +10,14 @@
from brainpy import math as bm, tools
from brainpy._src.initialize import parameter
-from brainpy._src.python_typing_copied import (_SpecialForm, _UnionGenericAlias,
- _type_check, _remove_dups_flatten)
+
from brainpy.types import ArrayType
+if sys.version_info >= (3, 9):
+  from typing import _UnionGenericAlias
+else:
+  from typing import _GenericAlias, _tp_cache
+
DynamicalSystem = None
__all__ = [
@@ -469,46 +475,99 @@ def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type:
return _MetaUnionType('UnionType', types, {})
-class _JointGenericAlias(_UnionGenericAlias, _root=True):
- def __subclasscheck__(self, subclass):
- return all([issubclass(subclass, cls) for cls in set(self.__args__)])
-
+if sys.version_info >= (3, 9):
+ class _JointGenericAlias(_UnionGenericAlias, _root=True):
+ def __subclasscheck__(self, subclass):
+ return all([issubclass(subclass, cls) for cls in set(self.__args__)])
-@_SpecialForm
-def JointType(self, parameters):
- """Joint type; JointType[X, Y] means either X or Y.
- To define a union, use e.g. Union[int, str]. Details:
- - The arguments must be types and there must be at least one.
- - None as an argument is a special case and is replaced by
- type(None).
- - Unions of unions are flattened, e.g.::
+ @_SpecialForm
+ def JointType(self, parameters):
+    """Joint type; JointType[X, Y] means both X and Y.
- JointType[JointType[int, str], float] == JointType[int, str, float]
+    To define a joint type, use e.g. JointType[int, str]. Details:
+ - The arguments must be types and there must be at least one.
+ - None as an argument is a special case and is replaced by
+ type(None).
+ - Unions of unions are flattened, e.g.::
- - Unions of a single argument vanish, e.g.::
+ JointType[JointType[int, str], float] == JointType[int, str, float]
- JointType[int] == int # The constructor actually returns int
+ - Unions of a single argument vanish, e.g.::
- - Redundant arguments are skipped, e.g.::
+ JointType[int] == int # The constructor actually returns int
- JointType[int, str, int] == JointType[int, str]
+ - Redundant arguments are skipped, e.g.::
- - When comparing unions, the argument order is ignored, e.g.::
+ JointType[int, str, int] == JointType[int, str]
- JointType[int, str] == JointType[str, int]
+ - When comparing unions, the argument order is ignored, e.g.::
- - You cannot subclass or instantiate a union.
- - You can use Optional[X] as a shorthand for JointType[X, None].
- """
- if parameters == ():
- raise TypeError("Cannot take a Union of no types.")
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- msg = "JointType[arg, ...]: each arg must be a type."
- parameters = tuple(_type_check(p, msg) for p in parameters)
- parameters = _remove_dups_flatten(parameters)
- if len(parameters) == 1:
- return parameters[0]
- return _JointGenericAlias(self, parameters)
+ JointType[int, str] == JointType[str, int]
+ - You cannot subclass or instantiate a union.
+ - You can use Optional[X] as a shorthand for JointType[X, None].
+ """
+ if parameters == ():
+ raise TypeError("Cannot take a Union of no types.")
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+ msg = "JointType[arg, ...]: each arg must be a type."
+ parameters = tuple(_type_check(p, msg) for p in parameters)
+ parameters = _remove_dups_flatten(parameters)
+ if len(parameters) == 1:
+ return parameters[0]
+ return _JointGenericAlias(self, parameters)
+
+else:
+ class _JointGenericAlias(_GenericAlias, _root=True):
+ def __subclasscheck__(self, subclass):
+ return all([issubclass(subclass, cls) for cls in set(self.__args__)])
+
+
+ class _SpecialForm2(_SpecialForm, _root=True):
+ @_tp_cache
+ def __getitem__(self, parameters):
+ if self._name == 'JointType':
+ if parameters == ():
+ raise TypeError("Cannot take a Union of no types.")
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+ msg = "Union[arg, ...]: each arg must be a type."
+ parameters = tuple(_type_check(p, msg) for p in parameters)
+ parameters = _remove_dups_flatten(parameters)
+ if len(parameters) == 1:
+ return parameters[0]
+ return _JointGenericAlias(self, parameters)
+ else:
+ return super().__getitem__(parameters)
+
+
+ JointType = _SpecialForm2(
+ 'JointType',
+    doc="""Joint type; JointType[X, Y] means both X and Y.
+
+    To define a joint type, use e.g. JointType[int, str]. Details:
+ - The arguments must be types and there must be at least one.
+ - None as an argument is a special case and is replaced by
+ type(None).
+ - Unions of unions are flattened, e.g.::
+
+ JointType[JointType[int, str], float] == JointType[int, str, float]
+
+ - Unions of a single argument vanish, e.g.::
+
+ JointType[int] == int # The constructor actually returns int
+
+ - Redundant arguments are skipped, e.g.::
+
+ JointType[int, str, int] == JointType[int, str]
+
+ - When comparing unions, the argument order is ignored, e.g.::
+
+ JointType[int, str] == JointType[str, int]
+
+ - You cannot subclass or instantiate a union.
+ - You can use Optional[X] as a shorthand for JointType[X, None].
+ """
+ )
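
Unlike typing.Union, the JointType form defined above is an intersection: its __subclasscheck__ requires a class to be a subclass of every member type. The toy stand-in below (not the project's implementation) illustrates the intended semantics:

    class A: ...
    class B: ...
    class AB(A, B): ...

    def satisfies_joint(cls, *types):
      # JointType[X, Y]: the class must satisfy *all* members, not just one.
      return all(issubclass(cls, t) for t in types)

    assert satisfies_joint(AB, A, B)        # AB is both an A and a B
    assert not satisfies_joint(A, A, B)     # A alone is not a B
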
diff --git a/brainpy/_src/python_typing_copied.py b/brainpy/_src/python_typing_copied.py
deleted file mode 100644
index 8e9b25276..000000000
--- a/brainpy/_src/python_typing_copied.py
+++ /dev/null
@@ -1,2273 +0,0 @@
-"""
-The typing module: Support for gradual typing as defined by PEP 484.
-
-At large scale, the structure of the module is following:
-* Imports and exports, all public names should be explicitly added to __all__.
-* Internal helper functions: these should never be used in code outside this module.
-* _SpecialForm and its instances (special forms): Any, NoReturn, ClassVar, Union, Optional
-* Two classes whose instances can be type arguments in addition to types: ForwardRef and TypeVar
-* The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is
- currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str],
- etc., are instances of either of these classes.
-* The public counterpart of the generics API consists of two classes: Generic and Protocol.
-* Public helper functions: get_type_hints, overload, cast, no_type_check,
- no_type_check_decorator.
-* Generic aliases for collections.abc ABCs and few additional protocols.
-* Special types: NewType, NamedTuple, TypedDict.
-* Wrapper submodules for re and io related types.
-"""
-
-from abc import abstractmethod, ABCMeta
-import collections
-import collections.abc
-import contextlib
-import functools
-import operator
-import re as stdlib_re # Avoid confusion with the re we export.
-import sys
-import types
-from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType, GenericAlias
-
-# Please keep __all__ alphabetized within each category.
-__all__ = [
- # Super-special typing primitives.
- 'Annotated',
- 'Any',
- 'Callable',
- 'ClassVar',
- 'Final',
- 'ForwardRef',
- 'Generic',
- 'Literal',
- 'Optional',
- 'Protocol',
- 'Tuple',
- 'Type',
- 'TypeVar',
- 'Union',
-
- # ABCs (from collections.abc).
- 'AbstractSet', # collections.abc.Set.
- 'ByteString',
- 'Container',
- 'ContextManager',
- 'Hashable',
- 'ItemsView',
- 'Iterable',
- 'Iterator',
- 'KeysView',
- 'Mapping',
- 'MappingView',
- 'MutableMapping',
- 'MutableSequence',
- 'MutableSet',
- 'Sequence',
- 'Sized',
- 'ValuesView',
- 'Awaitable',
- 'AsyncIterator',
- 'AsyncIterable',
- 'Coroutine',
- 'Collection',
- 'AsyncGenerator',
- 'AsyncContextManager',
-
- # Structural checks, a.k.a. protocols.
- 'Reversible',
- 'SupportsAbs',
- 'SupportsBytes',
- 'SupportsComplex',
- 'SupportsFloat',
- 'SupportsIndex',
- 'SupportsInt',
- 'SupportsRound',
-
- # Concrete collection types.
- 'ChainMap',
- 'Counter',
- 'Deque',
- 'Dict',
- 'DefaultDict',
- 'List',
- 'OrderedDict',
- 'Set',
- 'FrozenSet',
- 'NamedTuple', # Not really a type.
- 'TypedDict', # Not really a type.
- 'Generator',
-
- # Other concrete types.
- 'BinaryIO',
- 'IO',
- 'Match',
- 'Pattern',
- 'TextIO',
-
- # One-off things.
- 'AnyStr',
- 'cast',
- 'final',
- 'get_args',
- 'get_origin',
- 'get_type_hints',
- 'NewType',
- 'no_type_check',
- 'no_type_check_decorator',
- 'NoReturn',
- 'overload',
- 'runtime_checkable',
- 'Text',
- 'TYPE_CHECKING',
-]
-
-
-# The pseudo-submodules 're' and 'io' are part of the public
-# namespace, but excluded from __all__ because they might stomp on
-# legitimate imports of those modules.
-
-
-def _type_convert(arg, module=None, *, allow_special_forms=False):
- """For converting None to type(None), and strings to ForwardRef."""
- if arg is None:
- return type(None)
- if isinstance(arg, str):
- return ForwardRef(arg, module=module, is_class=allow_special_forms)
- return arg
-
-
-def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False):
- """Check that the argument is a type, and return it (internal helper).
-
- As a special case, accept None and return type(None) instead. Also wrap strings
- into ForwardRef instances. Consider several corner cases, for example plain
- special forms like Union are not valid, while Union[int, str] is OK, etc.
- The msg argument is a human-readable error message, e.g::
-
- "Union[arg, ...]: arg should be a type."
-
- We append the repr() of the actual value (truncated to 100 chars).
- """
- invalid_generic_forms = (Generic, Protocol)
- if not allow_special_forms:
- invalid_generic_forms += (ClassVar,)
- if is_argument:
- invalid_generic_forms += (Final,)
-
- arg = _type_convert(arg, module=module, allow_special_forms=allow_special_forms)
- if (isinstance(arg, _GenericAlias) and
- arg.__origin__ in invalid_generic_forms):
- raise TypeError(f"{arg} is not valid as type argument")
- if arg in (Any, NoReturn, Final):
- return arg
- if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
- raise TypeError(f"Plain {arg} is not valid as type argument")
- if isinstance(arg, (type, TypeVar, ForwardRef)):
- return arg
- if not callable(arg):
- raise TypeError(f"{msg} Got {arg!r:.100}.")
- return arg
-
-
-def _type_repr(obj):
- """Return the repr() of an object, special-casing types (internal helper).
-
- If obj is a type, we return a shorter version than the default
- type.__repr__, based on the module and qualified name, which is
- typically enough to uniquely identify a type. For everything
- else, we fall back on repr(obj).
- """
- if isinstance(obj, types.GenericAlias):
- return repr(obj)
- if isinstance(obj, type):
- if obj.__module__ == 'builtins':
- return obj.__qualname__
- return f'{obj.__module__}.{obj.__qualname__}'
- if obj is ...:
- return ('...')
- if isinstance(obj, types.FunctionType):
- return obj.__name__
- return repr(obj)
-
-
-def _collect_type_vars(types):
- """Collect all type variable contained in types in order of
- first appearance (lexicographic order). For example::
-
- _collect_type_vars((T, List[S, T])) == (T, S)
- """
- tvars = []
- for t in types:
- if isinstance(t, TypeVar) and t not in tvars:
- tvars.append(t)
- if isinstance(t, (_GenericAlias, GenericAlias)):
- tvars.extend([t for t in t.__parameters__ if t not in tvars])
- return tuple(tvars)
-
-
-def _check_generic(cls, parameters, elen):
- """Check correct count for parameters of a generic cls (internal helper).
- This gives a nice error message in case of count mismatch.
- """
- if not elen:
- raise TypeError(f"{cls} is not a generic class")
- alen = len(parameters)
- if alen != elen:
- raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};"
- f" actual {alen}, expected {elen}")
-
-
-def _deduplicate(params):
- # Weed out strict duplicates, preserving the first of each occurrence.
- all_params = set(params)
- if len(all_params) < len(params):
- new_params = []
- for t in params:
- if t in all_params:
- new_params.append(t)
- all_params.remove(t)
- params = new_params
- assert not all_params, all_params
- return params
-
-
-def _remove_dups_flatten(parameters):
- """An internal helper for Union creation and substitution: flatten Unions
- among parameters, then remove duplicates.
- """
- # Flatten out Union[Union[...], ...].
- params = []
- for p in parameters:
- if isinstance(p, _UnionGenericAlias):
- params.extend(p.__args__)
- elif isinstance(p, tuple) and len(p) > 0 and p[0] is Union:
- params.extend(p[1:])
- else:
- params.append(p)
-
- return tuple(_deduplicate(params))
-
-
-def _flatten_literal_params(parameters):
- """An internal helper for Literal creation: flatten Literals among parameters"""
- params = []
- for p in parameters:
- if isinstance(p, _LiteralGenericAlias):
- params.extend(p.__args__)
- else:
- params.append(p)
- return tuple(params)
-
-
-_cleanups = []
-
-
-def _tp_cache(func=None, /, *, typed=False):
- """Internal wrapper caching __getitem__ of generic types with a fallback to
- original function for non-hashable arguments.
- """
-
- def decorator(func):
- cached = functools.lru_cache(typed=typed)(func)
- _cleanups.append(cached.cache_clear)
-
- @functools.wraps(func)
- def inner(*args, **kwds):
- try:
- return cached(*args, **kwds)
- except TypeError:
- pass # All real errors (not unhashable args) are raised below.
- return func(*args, **kwds)
-
- return inner
-
- if func is not None:
- return decorator(func)
-
- return decorator
-
-
-def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
- """Evaluate all forward references in the given type t.
- For use of globalns and localns see the docstring for get_type_hints().
- recursive_guard is used to prevent infinite recursion with a recursive
- ForwardRef.
- """
- if isinstance(t, ForwardRef):
- return t._evaluate(globalns, localns, recursive_guard)
- if isinstance(t, (_GenericAlias, GenericAlias)):
- ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
- if ev_args == t.__args__:
- return t
- if isinstance(t, GenericAlias):
- return GenericAlias(t.__origin__, ev_args)
- else:
- return t.copy_with(ev_args)
- return t
-
-
-class _Final:
- """Mixin to prohibit subclassing"""
-
- __slots__ = ('__weakref__',)
-
- def __init_subclass__(self, /, *args, **kwds):
- if '_root' not in kwds:
- raise TypeError("Cannot subclass special typing classes")
-
-
-class _Immutable:
- """Mixin to indicate that object should not be copied."""
- __slots__ = ()
-
- def __copy__(self):
- return self
-
- def __deepcopy__(self, memo):
- return self
-
-
-# Internal indicator of special typing constructs.
-# See __doc__ instance attribute for specific docs.
-class _SpecialForm(_Final, _root=True):
- __slots__ = ('_name', '__doc__', '_getitem')
-
- def __init__(self, getitem):
- self._getitem = getitem
- self._name = getitem.__name__
- self.__doc__ = getitem.__doc__
-
- def __mro_entries__(self, bases):
- raise TypeError(f"Cannot subclass {self!r}")
-
- def __repr__(self):
- return 'typing.' + self._name
-
- def __reduce__(self):
- return self._name
-
- def __call__(self, *args, **kwds):
- raise TypeError(f"Cannot instantiate {self!r}")
-
- def __instancecheck__(self, obj):
- raise TypeError(f"{self} cannot be used with isinstance()")
-
- def __subclasscheck__(self, cls):
- raise TypeError(f"{self} cannot be used with issubclass()")
-
- @_tp_cache
- def __getitem__(self, parameters):
- return self._getitem(self, parameters)
-
-
-class _LiteralSpecialForm(_SpecialForm, _root=True):
- def __getitem__(self, parameters):
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- return self._getitem(self, *parameters)
-
-
-@_SpecialForm
-def Any(self, parameters):
- """Special type indicating an unconstrained type.
-
- - Any is compatible with every type.
- - Any assumed to have all methods.
- - All values assumed to be instances of Any.
-
- Note that all the above statements are true from the point of view of
- static type checkers. At runtime, Any should not be used with instance
- or class checks.
- """
- raise TypeError(f"{self} is not subscriptable")
-
-
-@_SpecialForm
-def NoReturn(self, parameters):
- """Special type indicating functions that never return.
- Example::
-
- from typing import NoReturn
-
- def stop() -> NoReturn:
- raise Exception('no way')
-
- This type is invalid in other positions, e.g., ``List[NoReturn]``
- will fail in static type checkers.
- """
- raise TypeError(f"{self} is not subscriptable")
-
-
-@_SpecialForm
-def ClassVar(self, parameters):
- """Special type construct to mark class variables.
-
- An annotation wrapped in ClassVar indicates that a given
- attribute is intended to be used as a class variable and
- should not be set on instances of that class. Usage::
-
- class Starship:
- stats: ClassVar[Dict[str, int]] = {} # class variable
- damage: int = 10 # instance variable
-
- ClassVar accepts only types and cannot be further subscribed.
-
- Note that ClassVar is not a class itself, and should not
- be used with isinstance() or issubclass().
- """
- item = _type_check(parameters, f'{self} accepts only single type.')
- return _GenericAlias(self, (item,))
-
-
-@_SpecialForm
-def Final(self, parameters):
- """Special typing construct to indicate final names to type checkers.
-
- A final name cannot be re-assigned or overridden in a subclass.
- For example:
-
- MAX_SIZE: Final = 9000
- MAX_SIZE += 1 # Error reported by type checker
-
- class Connection:
- TIMEOUT: Final[int] = 10
-
- class FastConnector(Connection):
- TIMEOUT = 1 # Error reported by type checker
-
- There is no runtime checking of these properties.
- """
- item = _type_check(parameters, f'{self} accepts only single type.')
- return _GenericAlias(self, (item,))
-
-
-@_SpecialForm
-def Union(self, parameters):
- """Union type; Union[X, Y] means either X or Y.
-
- To define a union, use e.g. Union[int, str]. Details:
- - The arguments must be types and there must be at least one.
- - None as an argument is a special case and is replaced by
- type(None).
- - Unions of unions are flattened, e.g.::
-
- Union[Union[int, str], float] == Union[int, str, float]
-
- - Unions of a single argument vanish, e.g.::
-
- Union[int] == int # The constructor actually returns int
-
- - Redundant arguments are skipped, e.g.::
-
- Union[int, str, int] == Union[int, str]
-
- - When comparing unions, the argument order is ignored, e.g.::
-
- Union[int, str] == Union[str, int]
-
- - You cannot subclass or instantiate a union.
- - You can use Optional[X] as a shorthand for Union[X, None].
- """
- if parameters == ():
- raise TypeError("Cannot take a Union of no types.")
- if not isinstance(parameters, tuple):
- parameters = (parameters,)
- msg = "Union[arg, ...]: each arg must be a type."
- parameters = tuple(_type_check(p, msg) for p in parameters)
- parameters = _remove_dups_flatten(parameters)
- if len(parameters) == 1:
- return parameters[0]
- return _UnionGenericAlias(self, parameters)
-
-
-@_SpecialForm
-def Optional(self, parameters):
- """Optional type.
-
- Optional[X] is equivalent to Union[X, None].
- """
- arg = _type_check(parameters, f"{self} requires a single type.")
- return Union[arg, type(None)]
-
-
-@_LiteralSpecialForm
-@_tp_cache(typed=True)
-def Literal(self, *parameters):
- """Special typing form to define literal types (a.k.a. value types).
-
- This form can be used to indicate to type checkers that the corresponding
- variable or function parameter has a value equivalent to the provided
- literal (or one of several literals):
-
- def validate_simple(data: Any) -> Literal[True]: # always returns True
- ...
-
- MODE = Literal['r', 'rb', 'w', 'wb']
- def open_helper(file: str, mode: MODE) -> str:
- ...
-
- open_helper('/some/path', 'r') # Passes type check
- open_helper('/other/path', 'typo') # Error in type checker
-
- Literal[...] cannot be subclassed. At runtime, an arbitrary value
- is allowed as type argument to Literal[...], but type checkers may
- impose restrictions.
- """
- # There is no '_type_check' call because arguments to Literal[...] are
- # values, not types.
- parameters = _flatten_literal_params(parameters)
-
- try:
- parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
- except TypeError: # unhashable parameters
- pass
-
- return _LiteralGenericAlias(self, parameters)
-
-
-class ForwardRef(_Final, _root=True):
- """Internal wrapper to hold a forward reference."""
-
- __slots__ = ('__forward_arg__', '__forward_code__',
- '__forward_evaluated__', '__forward_value__',
- '__forward_is_argument__', '__forward_is_class__',
- '__forward_module__')
-
- def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
- if not isinstance(arg, str):
- raise TypeError(f"Forward reference must be a string -- got {arg!r}")
- try:
- code = compile(arg, '', 'eval')
- except SyntaxError:
- raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
- self.__forward_arg__ = arg
- self.__forward_code__ = code
- self.__forward_evaluated__ = False
- self.__forward_value__ = None
- self.__forward_is_argument__ = is_argument
- self.__forward_is_class__ = is_class
- self.__forward_module__ = module
-
- def _evaluate(self, globalns, localns, recursive_guard):
- if self.__forward_arg__ in recursive_guard:
- return self
- if not self.__forward_evaluated__ or localns is not globalns:
- if globalns is None and localns is None:
- globalns = localns = {}
- elif globalns is None:
- globalns = localns
- elif localns is None:
- localns = globalns
- if self.__forward_module__ is not None:
- globalns = getattr(
- sys.modules.get(self.__forward_module__, None), '__dict__', globalns
- )
- type_ = _type_check(
- eval(self.__forward_code__, globalns, localns),
- "Forward references must evaluate to types.",
- is_argument=self.__forward_is_argument__,
- allow_special_forms=self.__forward_is_class__,
- )
- self.__forward_value__ = _eval_type(
- type_, globalns, localns, recursive_guard | {self.__forward_arg__}
- )
- self.__forward_evaluated__ = True
- return self.__forward_value__
-
- def __eq__(self, other):
- if not isinstance(other, ForwardRef):
- return NotImplemented
- if self.__forward_evaluated__ and other.__forward_evaluated__:
- return (self.__forward_arg__ == other.__forward_arg__ and
- self.__forward_value__ == other.__forward_value__)
- return (self.__forward_arg__ == other.__forward_arg__ and
- self.__forward_module__ == other.__forward_module__)
-
- def __hash__(self):
- return hash((self.__forward_arg__, self.__forward_module__))
-
- def __repr__(self):
- return f'ForwardRef({self.__forward_arg__!r})'
-
-
-class TypeVar(_Final, _Immutable, _root=True):
- """Type variable.
-
- Usage::
-
- T = TypeVar('T') # Can be anything
- A = TypeVar('A', str, bytes) # Must be str or bytes
-
- Type variables exist primarily for the benefit of static type
- checkers. They serve as the parameters for generic types as well
- as for generic function definitions. See class Generic for more
- information on generic types. Generic functions work as follows:
-
- def repeat(x: T, n: int) -> List[T]:
- '''Return a list containing n references to x.'''
- return [x]*n
-
- def longest(x: A, y: A) -> A:
- '''Return the longest of two strings.'''
- return x if len(x) >= len(y) else y
-
- The latter example's signature is essentially the overloading
- of (str, str) -> str and (bytes, bytes) -> bytes. Also note
- that if the arguments are instances of some subclass of str,
- the return type is still plain str.
-
- At runtime, isinstance(x, T) and issubclass(C, T) will raise TypeError.
-
- Type variables defined with covariant=True or contravariant=True
- can be used to declare covariant or contravariant generic types.
- See PEP 484 for more details. By default generic types are invariant
- in all type variables.
-
- Type variables can be introspected. e.g.:
-
- T.__name__ == 'T'
- T.__constraints__ == ()
- T.__covariant__ == False
- T.__contravariant__ = False
- A.__constraints__ == (str, bytes)
-
- Note that only type variables defined in global scope can be pickled.
- """
-
- __slots__ = ('__name__', '__bound__', '__constraints__',
- '__covariant__', '__contravariant__', '__dict__')
-
- def __init__(self, name, *constraints, bound=None,
- covariant=False, contravariant=False):
- self.__name__ = name
- if covariant and contravariant:
- raise ValueError("Bivariant types are not supported.")
- self.__covariant__ = bool(covariant)
- self.__contravariant__ = bool(contravariant)
- if constraints and bound is not None:
- raise TypeError("Constraints cannot be combined with bound=...")
- if constraints and len(constraints) == 1:
- raise TypeError("A single constraint is not allowed")
- msg = "TypeVar(name, constraint, ...): constraints must be types."
- self.__constraints__ = tuple(_type_check(t, msg) for t in constraints)
- if bound:
- self.__bound__ = _type_check(bound, "Bound must be a type.")
- else:
- self.__bound__ = None
- try:
- def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') # for pickling
- except (AttributeError, ValueError):
- def_mod = None
- if def_mod != 'typing':
- self.__module__ = def_mod
-
- def __repr__(self):
- if self.__covariant__:
- prefix = '+'
- elif self.__contravariant__:
- prefix = '-'
- else:
- prefix = '~'
- return prefix + self.__name__
-
- def __reduce__(self):
- return self.__name__
-
-
-def _is_dunder(attr):
- return attr.startswith('__') and attr.endswith('__')
-
-
-class _BaseGenericAlias(_Final, _root=True):
- """The central part of internal API.
-
- This represents a generic version of type 'origin' with type arguments 'params'.
- There are two kind of these aliases: user defined and special. The special ones
- are wrappers around builtin collections and ABCs in collections.abc. These must
- have 'name' always set. If 'inst' is False, then the alias can't be instantiated,
- this is used by e.g. typing.List and typing.Dict.
- """
-
- def __init__(self, origin, *, inst=True, name=None):
- self._inst = inst
- self._name = name
- self.__origin__ = origin
- self.__slots__ = None # This is not documented.
-
- def __call__(self, *args, **kwargs):
- if not self._inst:
- raise TypeError(f"Type {self._name} cannot be instantiated; "
- f"use {self.__origin__.__name__}() instead")
- result = self.__origin__(*args, **kwargs)
- try:
- result.__orig_class__ = self
- except AttributeError:
- pass
- return result
-
- def __mro_entries__(self, bases):
- res = []
- if self.__origin__ not in bases:
- res.append(self.__origin__)
- i = bases.index(self)
- for b in bases[i + 1:]:
- if isinstance(b, _BaseGenericAlias) or issubclass(b, Generic):
- break
- else:
- res.append(Generic)
- return tuple(res)
-
- def __getattr__(self, attr):
- # We are careful for copy and pickle.
- # Also for simplicity we don't relay any dunder names
- if '__origin__' in self.__dict__ and not _is_dunder(attr):
- return getattr(self.__origin__, attr)
- raise AttributeError(attr)
-
- def __setattr__(self, attr, val):
- if _is_dunder(attr) or attr in ('_name', '_inst', '_nparams'):
- super().__setattr__(attr, val)
- else:
- setattr(self.__origin__, attr, val)
-
- def __instancecheck__(self, obj):
- return self.__subclasscheck__(type(obj))
-
- def __subclasscheck__(self, cls):
- raise TypeError("Subscripted generics cannot be used with"
- " class and instance checks")
-
-
-# Special typing constructs Union, Optional, Generic, Callable and Tuple
-# use three special attributes for internal bookkeeping of generic types:
-# * __parameters__ is a tuple of unique free type parameters of a generic
-# type, for example, Dict[T, T].__parameters__ == (T,);
-# * __origin__ keeps a reference to a type that was subscripted,
-# e.g., Union[T, int].__origin__ == Union, or the non-generic version of
-# the type.
-# * __args__ is a tuple of all arguments used in subscripting,
-# e.g., Dict[T, int].__args__ == (T, int).
-
-
-class _GenericAlias(_BaseGenericAlias, _root=True):
- def __init__(self, origin, params, *, inst=True, name=None):
- super().__init__(origin, inst=inst, name=name)
- if not isinstance(params, tuple):
- params = (params,)
- self.__args__ = tuple(... if a is _TypingEllipsis else
- () if a is _TypingEmpty else
- a for a in params)
- self.__parameters__ = _collect_type_vars(params)
- if not name:
- self.__module__ = origin.__module__
-
- def __eq__(self, other):
- if not isinstance(other, _GenericAlias):
- return NotImplemented
- return (self.__origin__ == other.__origin__
- and self.__args__ == other.__args__)
-
- def __hash__(self):
- return hash((self.__origin__, self.__args__))
-
- @_tp_cache
- def __getitem__(self, params):
- if self.__origin__ in (Generic, Protocol):
- # Can't subscript Generic[...] or Protocol[...].
- raise TypeError(f"Cannot subscript already-subscripted {self}")
- if not isinstance(params, tuple):
- params = (params,)
- msg = "Parameters to generic types must be types."
- params = tuple(_type_check(p, msg) for p in params)
- _check_generic(self, params, len(self.__parameters__))
-
- subst = dict(zip(self.__parameters__, params))
- new_args = []
- for arg in self.__args__:
- if isinstance(arg, TypeVar):
- arg = subst[arg]
- elif isinstance(arg, (_GenericAlias, GenericAlias)):
- subparams = arg.__parameters__
- if subparams:
- subargs = tuple(subst[x] for x in subparams)
- arg = arg[subargs]
- new_args.append(arg)
- return self.copy_with(tuple(new_args))
-
- def copy_with(self, params):
- return self.__class__(self.__origin__, params, name=self._name, inst=self._inst)
-
- def __repr__(self):
- if self._name:
- name = 'typing.' + self._name
- else:
- name = _type_repr(self.__origin__)
- args = ", ".join([_type_repr(a) for a in self.__args__])
- return f'{name}[{args}]'
-
- def __reduce__(self):
- if self._name:
- origin = globals()[self._name]
- else:
- origin = self.__origin__
- args = tuple(self.__args__)
- if len(args) == 1 and not isinstance(args[0], tuple):
- args, = args
- return operator.getitem, (origin, args)
-
- def __mro_entries__(self, bases):
- if self._name: # generic version of an ABC or built-in class
- return super().__mro_entries__(bases)
- if self.__origin__ is Generic:
- if Protocol in bases:
- return ()
- i = bases.index(self)
- for b in bases[i + 1:]:
- if isinstance(b, _BaseGenericAlias) and b is not self:
- return ()
- return (self.__origin__,)
-
-
-# _nparams is the number of accepted parameters, e.g. 0 for Hashable,
-# 1 for List and 2 for Dict. It may be -1 if variable number of
-# parameters are accepted (needs custom __getitem__).
-
-class _SpecialGenericAlias(_BaseGenericAlias, _root=True):
- def __init__(self, origin, nparams, *, inst=True, name=None):
- if name is None:
- name = origin.__name__
- super().__init__(origin, inst=inst, name=name)
- self._nparams = nparams
- if origin.__module__ == 'builtins':
- self.__doc__ = f'A generic version of {origin.__qualname__}.'
- else:
- self.__doc__ = f'A generic version of {origin.__module__}.{origin.__qualname__}.'
-
- @_tp_cache
- def __getitem__(self, params):
- if not isinstance(params, tuple):
- params = (params,)
- msg = "Parameters to generic types must be types."
- params = tuple(_type_check(p, msg) for p in params)
- _check_generic(self, params, self._nparams)
- return self.copy_with(params)
-
- def copy_with(self, params):
- return _GenericAlias(self.__origin__, params,
- name=self._name, inst=self._inst)
-
- def __repr__(self):
- return 'typing.' + self._name
-
- def __subclasscheck__(self, cls):
- if isinstance(cls, _SpecialGenericAlias):
- return issubclass(cls.__origin__, self.__origin__)
- if not isinstance(cls, _GenericAlias):
- return issubclass(cls, self.__origin__)
- return super().__subclasscheck__(cls)
-
- def __reduce__(self):
- return self._name
-
-
-class _CallableGenericAlias(_GenericAlias, _root=True):
- def __repr__(self):
- assert self._name == 'Callable'
- if len(self.__args__) == 2 and self.__args__[0] is Ellipsis:
- return super().__repr__()
- return (f'typing.Callable'
- f'[[{", ".join([_type_repr(a) for a in self.__args__[:-1]])}], '
- f'{_type_repr(self.__args__[-1])}]')
-
- def __reduce__(self):
- args = self.__args__
- if not (len(args) == 2 and args[0] is ...):
- args = list(args[:-1]), args[-1]
- return operator.getitem, (Callable, args)
-
-
-class _CallableType(_SpecialGenericAlias, _root=True):
- def copy_with(self, params):
- return _CallableGenericAlias(self.__origin__, params,
- name=self._name, inst=self._inst)
-
- def __getitem__(self, params):
- if not isinstance(params, tuple) or len(params) != 2:
- raise TypeError("Callable must be used as "
- "Callable[[arg, ...], result].")
- args, result = params
- # This relaxes what args can be on purpose to allow things like
- # PEP 612 ParamSpec. Responsibility for whether a user is using
- # Callable[...] properly is deferred to static type checkers.
- if isinstance(args, list):
- params = (tuple(args), result)
- else:
- params = (args, result)
- return self.__getitem_inner__(params)
-
- @_tp_cache
- def __getitem_inner__(self, params):
- args, result = params
- msg = "Callable[args, result]: result must be a type."
- result = _type_check(result, msg)
- if args is Ellipsis:
- return self.copy_with((_TypingEllipsis, result))
- if not isinstance(args, tuple):
- args = (args,)
- args = tuple(_type_convert(arg) for arg in args)
- params = args + (result,)
- return self.copy_with(params)
-
-
-class _TupleType(_SpecialGenericAlias, _root=True):
- @_tp_cache
- def __getitem__(self, params):
- if params == ():
- return self.copy_with((_TypingEmpty,))
- if not isinstance(params, tuple):
- params = (params,)
- if len(params) == 2 and params[1] is ...:
- msg = "Tuple[t, ...]: t must be a type."
- p = _type_check(params[0], msg)
- return self.copy_with((p, _TypingEllipsis))
- msg = "Tuple[t0, t1, ...]: each t must be a type."
- params = tuple(_type_check(p, msg) for p in params)
- return self.copy_with(params)
-
-
-class _UnionGenericAlias(_GenericAlias, _root=True):
- def copy_with(self, params):
- return Union[params]
-
- def __eq__(self, other):
- if not isinstance(other, _UnionGenericAlias):
- return NotImplemented
- return set(self.__args__) == set(other.__args__)
-
- def __hash__(self):
- return hash(frozenset(self.__args__))
-
- def __repr__(self):
- args = self.__args__
- if len(args) == 2:
- if args[0] is type(None):
- return f'typing.Optional[{_type_repr(args[1])}]'
- elif args[1] is type(None):
- return f'typing.Optional[{_type_repr(args[0])}]'
- return super().__repr__()
-
-
-def _value_and_type_iter(parameters):
- return ((p, type(p)) for p in parameters)
-
-
-class _LiteralGenericAlias(_GenericAlias, _root=True):
-
- def __eq__(self, other):
- if not isinstance(other, _LiteralGenericAlias):
- return NotImplemented
-
- return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
-
- def __hash__(self):
- return hash(frozenset(_value_and_type_iter(self.__args__)))
-
-
-class Generic:
- """Abstract base class for generic types.
-
- A generic type is typically declared by inheriting from
- this class parameterized with one or more type variables.
- For example, a generic mapping type might be defined as::
-
- class Mapping(Generic[KT, VT]):
- def __getitem__(self, key: KT) -> VT:
- ...
- # Etc.
-
- This class can then be used as follows::
-
- def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
- try:
- return mapping[key]
- except KeyError:
- return default
- """
- __slots__ = ()
- _is_protocol = False
-
- @_tp_cache
- def __class_getitem__(cls, params):
- if not isinstance(params, tuple):
- params = (params,)
- if not params and cls is not Tuple:
- raise TypeError(
- f"Parameter list to {cls.__qualname__}[...] cannot be empty")
- msg = "Parameters to generic types must be types."
- params = tuple(_type_check(p, msg) for p in params)
- if cls in (Generic, Protocol):
- # Generic and Protocol can only be subscripted with unique type variables.
- if not all(isinstance(p, TypeVar) for p in params):
- raise TypeError(
- f"Parameters to {cls.__name__}[...] must all be type variables")
- if len(set(params)) != len(params):
- raise TypeError(
- f"Parameters to {cls.__name__}[...] must all be unique")
- else:
- # Subscripting a regular Generic subclass.
- _check_generic(cls, params, len(cls.__parameters__))
- return _GenericAlias(cls, params)
-
- def __init_subclass__(cls, *args, **kwargs):
- super().__init_subclass__(*args, **kwargs)
- tvars = []
- if '__orig_bases__' in cls.__dict__:
- error = Generic in cls.__orig_bases__
- else:
- error = Generic in cls.__bases__ and cls.__name__ != 'Protocol'
- if error:
- raise TypeError("Cannot inherit from plain Generic")
- if '__orig_bases__' in cls.__dict__:
- tvars = _collect_type_vars(cls.__orig_bases__)
- # Look for Generic[T1, ..., Tn].
- # If found, tvars must be a subset of it.
- # If not found, tvars is it.
- # Also check for and reject plain Generic,
- # and reject multiple Generic[...].
- gvars = None
- for base in cls.__orig_bases__:
- if (isinstance(base, _GenericAlias) and
- base.__origin__ is Generic):
- if gvars is not None:
- raise TypeError(
- "Cannot inherit from Generic[...] multiple types.")
- gvars = base.__parameters__
- if gvars is not None:
- tvarset = set(tvars)
- gvarset = set(gvars)
- if not tvarset <= gvarset:
- s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
- s_args = ', '.join(str(g) for g in gvars)
- raise TypeError(f"Some type variables ({s_vars}) are"
- f" not listed in Generic[{s_args}]")
- tvars = gvars
- cls.__parameters__ = tuple(tvars)
-
-
-class _TypingEmpty:
- """Internal placeholder for () or []. Used by TupleMeta and CallableMeta
- to allow empty list/tuple in specific places, without allowing them
- to sneak in where prohibited.
- """
-
-
-class _TypingEllipsis:
- """Internal placeholder for ... (ellipsis)."""
-
-
-_TYPING_INTERNALS = ['__parameters__', '__orig_bases__', '__orig_class__',
- '_is_protocol', '_is_runtime_protocol']
-
-_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
- '__init__', '__module__', '__new__', '__slots__',
- '__subclasshook__', '__weakref__', '__class_getitem__']
-
-# These special attributes will be not collected as protocol members.
-EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']
-
-
-def _get_protocol_attrs(cls):
- """Collect protocol members from a protocol class objects.
-
- This includes names actually defined in the class dictionary, as well
- as names that appear in annotations. Special names (above) are skipped.
- """
- attrs = set()
- for base in cls.__mro__[:-1]: # without object
- if base.__name__ in ('Protocol', 'Generic'):
- continue
- annotations = getattr(base, '__annotations__', {})
- for attr in list(base.__dict__.keys()) + list(annotations.keys()):
- if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
- attrs.add(attr)
- return attrs
-
-
-def _is_callable_members_only(cls):
- # PEP 544 prohibits using issubclass() with protocols that have non-method members.
- return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
-
-
-def _no_init_or_replace_init(self, *args, **kwargs):
- cls = type(self)
-
- if cls._is_protocol:
- raise TypeError('Protocols cannot be instantiated')
-
- # Already using a custom `__init__`. No need to calculate correct
- # `__init__` to call. This can lead to RecursionError. See bpo-45121.
- if cls.__init__ is not _no_init_or_replace_init:
- return
-
- # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
- # The first instantiation of the subclass will call `_no_init_or_replace_init` which
- # searches for a proper new `__init__` in the MRO. The new `__init__`
- # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
- # instantiation of the protocol subclass will thus use the new
- # `__init__` and no longer call `_no_init_or_replace_init`.
- for base in cls.__mro__:
- init = base.__dict__.get('__init__', _no_init_or_replace_init)
- if init is not _no_init_or_replace_init:
- cls.__init__ = init
- break
- else:
- # should not happen
- cls.__init__ = object.__init__
-
- cls.__init__(self, *args, **kwargs)
-
-
-def _allow_reckless_class_cheks():
- """Allow instance and class checks for special stdlib modules.
-
- The abc and functools modules indiscriminately call isinstance() and
- issubclass() on the whole MRO of a user class, which may contain protocols.
- """
- try:
- return sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']
- except (AttributeError, ValueError): # For platforms without _getframe().
- return True
-
-
-_PROTO_WHITELIST = {
- 'collections.abc': [
- 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable',
- 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
- ],
- 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'],
-}
-
-
-class _ProtocolMeta(ABCMeta):
- # This metaclass is really unfortunate and exists only because of
- # the lack of __instancehook__.
- def __instancecheck__(cls, instance):
- # We need this method for situations where attributes are
- # assigned in __init__.
- if ((not getattr(cls, '_is_protocol', False) or
- _is_callable_members_only(cls)) and
- issubclass(instance.__class__, cls)):
- return True
- if cls._is_protocol:
- if all(hasattr(instance, attr) and
- # All *methods* can be blocked by setting them to None.
- (not callable(getattr(cls, attr, None)) or
- getattr(instance, attr) is not None)
- for attr in _get_protocol_attrs(cls)):
- return True
- return super().__instancecheck__(instance)
-
-
-class Protocol(Generic, metaclass=_ProtocolMeta):
- """Base class for protocol classes.
-
- Protocol classes are defined as::
-
- class Proto(Protocol):
- def meth(self) -> int:
- ...
-
- Such classes are primarily used with static type checkers that recognize
- structural subtyping (static duck-typing), for example::
-
- class C:
- def meth(self) -> int:
- return 0
-
- def func(x: Proto) -> int:
- return x.meth()
-
- func(C()) # Passes static type check
-
- See PEP 544 for details. Protocol classes decorated with
- @typing.runtime_checkable act as simple-minded runtime protocols that check
- only the presence of given attributes, ignoring their type signatures.
- Protocol classes can be generic, they are defined as::
-
- class GenProto(Protocol[T]):
- def meth(self) -> T:
- ...
- """
- __slots__ = ()
- _is_protocol = True
- _is_runtime_protocol = False
-
- def __init_subclass__(cls, *args, **kwargs):
- super().__init_subclass__(*args, **kwargs)
-
- # Determine if this is a protocol or a concrete subclass.
- if not cls.__dict__.get('_is_protocol', False):
- cls._is_protocol = any(b is Protocol for b in cls.__bases__)
-
- # Set (or override) the protocol subclass hook.
- def _proto_hook(other):
- if not cls.__dict__.get('_is_protocol', False):
- return NotImplemented
-
- # First, perform various sanity checks.
- if not getattr(cls, '_is_runtime_protocol', False):
- if _allow_reckless_class_cheks():
- return NotImplemented
- raise TypeError("Instance and class checks can only be used with"
- " @runtime_checkable protocols")
- if not _is_callable_members_only(cls):
- if _allow_reckless_class_cheks():
- return NotImplemented
- raise TypeError("Protocols with non-method members"
- " don't support issubclass()")
- if not isinstance(other, type):
- # Same error message as for issubclass(1, int).
- raise TypeError('issubclass() arg 1 must be a class')
-
- # Second, perform the actual structural compatibility check.
- for attr in _get_protocol_attrs(cls):
- for base in other.__mro__:
- # Check if the members appears in the class dictionary...
- if attr in base.__dict__:
- if base.__dict__[attr] is None:
- return NotImplemented
- break
-
- # ...or in annotations, if it is a sub-protocol.
- annotations = getattr(base, '__annotations__', {})
- if (isinstance(annotations, collections.abc.Mapping) and
- attr in annotations and
- issubclass(other, Generic) and other._is_protocol):
- break
- else:
- return NotImplemented
- return True
-
- if '__subclasshook__' not in cls.__dict__:
- cls.__subclasshook__ = _proto_hook
-
- # We have nothing more to do for non-protocols...
- if not cls._is_protocol:
- return
-
- # ... otherwise check consistency of bases, and prohibit instantiation.
- for base in cls.__bases__:
- if not (base in (object, Generic) or
- base.__module__ in _PROTO_WHITELIST and
- base.__name__ in _PROTO_WHITELIST[base.__module__] or
- issubclass(base, Generic) and base._is_protocol):
- raise TypeError('Protocols can only inherit from other'
- ' protocols, got %r' % base)
- cls.__init__ = _no_init_or_replace_init
-
-
-class _AnnotatedAlias(_GenericAlias, _root=True):
- """Runtime representation of an annotated type.
-
- At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
- with extra annotations. The alias behaves like a normal typing alias,
- instantiating is the same as instantiating the underlying type, binding
- it to types is also the same.
- """
-
- def __init__(self, origin, metadata):
- if isinstance(origin, _AnnotatedAlias):
- metadata = origin.__metadata__ + metadata
- origin = origin.__origin__
- super().__init__(origin, origin)
- self.__metadata__ = metadata
-
- def copy_with(self, params):
- assert len(params) == 1
- new_type = params[0]
- return _AnnotatedAlias(new_type, self.__metadata__)
-
- def __repr__(self):
- return "typing.Annotated[{}, {}]".format(
- _type_repr(self.__origin__),
- ", ".join(repr(a) for a in self.__metadata__)
- )
-
- def __reduce__(self):
- return operator.getitem, (
- Annotated, (self.__origin__,) + self.__metadata__
- )
-
- def __eq__(self, other):
- if not isinstance(other, _AnnotatedAlias):
- return NotImplemented
- return (self.__origin__ == other.__origin__
- and self.__metadata__ == other.__metadata__)
-
- def __hash__(self):
- return hash((self.__origin__, self.__metadata__))
-
-
-class Annotated:
- """Add context specific metadata to a type.
-
- Example: Annotated[int, runtime_check.Unsigned] indicates to the
- hypothetical runtime_check module that this type is an unsigned int.
- Every other consumer of this type can ignore this metadata and treat
- this type as int.
-
- The first argument to Annotated must be a valid type.
-
- Details:
-
- - It's an error to call `Annotated` with less than two arguments.
- - Nested Annotated are flattened::
-
- Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3]
-
- - Instantiating an annotated type is equivalent to instantiating the
- underlying type::
-
- Annotated[C, Ann1](5) == C(5)
-
- - Annotated can be used as a generic type alias::
-
- Optimized = Annotated[T, runtime.Optimize()]
- Optimized[int] == Annotated[int, runtime.Optimize()]
-
- OptimizedList = Annotated[List[T], runtime.Optimize()]
- OptimizedList[int] == Annotated[List[int], runtime.Optimize()]
- """
-
- __slots__ = ()
-
- def __new__(cls, *args, **kwargs):
- raise TypeError("Type Annotated cannot be instantiated.")
-
- @_tp_cache
- def __class_getitem__(cls, params):
- if not isinstance(params, tuple) or len(params) < 2:
- raise TypeError("Annotated[...] should be used "
- "with at least two arguments (a type and an "
- "annotation).")
- msg = "Annotated[t, ...]: t must be a type."
- origin = _type_check(params[0], msg, allow_special_forms=True)
- metadata = tuple(params[1:])
- return _AnnotatedAlias(origin, metadata)
-
- def __init_subclass__(cls, *args, **kwargs):
- raise TypeError(
- "Cannot subclass {}.Annotated".format(cls.__module__)
- )
-
-
-def runtime_checkable(cls):
- """Mark a protocol class as a runtime protocol.
-
- Such protocol can be used with isinstance() and issubclass().
- Raise TypeError if applied to a non-protocol class.
- This allows a simple-minded structural check very similar to
- one trick ponies in collections.abc such as Iterable.
- For example::
-
- @runtime_checkable
- class Closable(Protocol):
- def close(self): ...
-
- assert isinstance(open('/some/file'), Closable)
-
- Warning: this will check only the presence of the required methods,
- not their type signatures!
- """
- if not issubclass(cls, Generic) or not cls._is_protocol:
- raise TypeError('@runtime_checkable can be only applied to protocol classes,'
- ' got %r' % cls)
- cls._is_runtime_protocol = True
- return cls
-
-
-def cast(typ, val):
- """Cast a value to a type.
-
- This returns the value unchanged. To the type checker this
- signals that the return value has the designated type, but at
- runtime we intentionally don't check anything (we want this
- to be as fast as possible).
- """
- return val
-
-
-def _get_defaults(func):
- """Internal helper to extract the default arguments, by name."""
- try:
- code = func.__code__
- except AttributeError:
- # Some built-in functions don't have __code__, __defaults__, etc.
- return {}
- pos_count = code.co_argcount
- arg_names = code.co_varnames
- arg_names = arg_names[:pos_count]
- defaults = func.__defaults__ or ()
- kwdefaults = func.__kwdefaults__
- res = dict(kwdefaults) if kwdefaults else {}
- pos_offset = pos_count - len(defaults)
- for name, value in zip(arg_names[pos_offset:], defaults):
- assert name not in res
- res[name] = value
- return res
-
-
-_allowed_types = (types.FunctionType, types.BuiltinFunctionType,
- types.MethodType, types.ModuleType,
- WrapperDescriptorType, MethodWrapperType, MethodDescriptorType)
-
-
-def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
- """Return type hints for an object.
-
- This is often the same as obj.__annotations__, but it handles
- forward references encoded as string literals, adds Optional[t] if a
- default value equal to None is set and recursively replaces all
- 'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
-
- The argument may be a module, class, method, or function. The annotations
- are returned as a dictionary. For classes, annotations include also
- inherited members.
-
- TypeError is raised if the argument is not of a type that can contain
- annotations, and an empty dictionary is returned if no annotations are
- present.
-
- BEWARE -- the behavior of globalns and localns is counterintuitive
- (unless you are familiar with how eval() and exec() work). The
- search order is locals first, then globals.
-
- - If no dict arguments are passed, an attempt is made to use the
- globals from obj (or the respective module's globals for classes),
- and these are also used as the locals. If the object does not appear
- to have globals, an empty dictionary is used.
-
- - If one dict argument is passed, it is used for both globals and
- locals.
-
- - If two dict arguments are passed, they specify globals and
- locals, respectively.
- """
-
- if getattr(obj, '__no_type_check__', None):
- return {}
- # Classes require a special treatment.
- if isinstance(obj, type):
- hints = {}
- for base in reversed(obj.__mro__):
- if globalns is None:
- base_globals = sys.modules[base.__module__].__dict__
- else:
- base_globals = globalns
- ann = base.__dict__.get('__annotations__', {})
- for name, value in ann.items():
- if value is None:
- value = type(None)
- if isinstance(value, str):
- value = ForwardRef(value, is_argument=False, is_class=True)
- value = _eval_type(value, base_globals, localns)
- hints[name] = value
- return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
-
- if globalns is None:
- if isinstance(obj, types.ModuleType):
- globalns = obj.__dict__
- else:
- nsobj = obj
- # Find globalns for the unwrapped object.
- while hasattr(nsobj, '__wrapped__'):
- nsobj = nsobj.__wrapped__
- globalns = getattr(nsobj, '__globals__', {})
- if localns is None:
- localns = globalns
- elif localns is None:
- localns = globalns
- hints = getattr(obj, '__annotations__', None)
- if hints is None:
- # Return empty annotations for something that _could_ have them.
- if isinstance(obj, _allowed_types):
- return {}
- else:
- raise TypeError('{!r} is not a module, class, method, '
- 'or function.'.format(obj))
- defaults = _get_defaults(obj)
- hints = dict(hints)
- for name, value in hints.items():
- if value is None:
- value = type(None)
- if isinstance(value, str):
- # class-level forward refs were handled above, this must be either
- # a module-level annotation or a function argument annotation
- value = ForwardRef(
- value,
- is_argument=not isinstance(obj, types.ModuleType),
- is_class=False,
- )
- value = _eval_type(value, globalns, localns)
- if name in defaults and defaults[name] is None:
- value = Optional[value]
- hints[name] = value
- return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}
-
-
-def _strip_annotations(t):
- """Strips the annotations from a given type.
- """
- if isinstance(t, _AnnotatedAlias):
- return _strip_annotations(t.__origin__)
- if isinstance(t, _GenericAlias):
- stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
- if stripped_args == t.__args__:
- return t
- return t.copy_with(stripped_args)
- if isinstance(t, GenericAlias):
- stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
- if stripped_args == t.__args__:
- return t
- return GenericAlias(t.__origin__, stripped_args)
- return t
-
-
-def get_origin(tp):
- """Get the unsubscripted version of a type.
-
- This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar
- and Annotated. Return None for unsupported types. Examples::
-
- get_origin(Literal[42]) is Literal
- get_origin(int) is None
- get_origin(ClassVar[int]) is ClassVar
- get_origin(Generic) is Generic
- get_origin(Generic[T]) is Generic
- get_origin(Union[T, int]) is Union
- get_origin(List[Tuple[T, T]][int]) == list
- """
- if isinstance(tp, _AnnotatedAlias):
- return Annotated
- if isinstance(tp, (_BaseGenericAlias, GenericAlias)):
- return tp.__origin__
- if tp is Generic:
- return Generic
- return None
-
-
-def get_args(tp):
- """Get type arguments with all substitutions performed.
-
- For unions, basic simplifications used by Union constructor are performed.
- Examples::
- get_args(Dict[str, int]) == (str, int)
- get_args(int) == ()
- get_args(Union[int, Union[T, int], str][int]) == (int, str)
- get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
- get_args(Callable[[], T][int]) == ([], int)
- """
- if isinstance(tp, _AnnotatedAlias):
- return (tp.__origin__,) + tp.__metadata__
- if isinstance(tp, (_GenericAlias, GenericAlias)):
- res = tp.__args__
- if tp.__origin__ is collections.abc.Callable and res[0] is not Ellipsis:
- res = (list(res[:-1]), res[-1])
- return res
- return ()
-
-
-def no_type_check(arg):
- """Decorator to indicate that annotations are not type hints.
-
- The argument must be a class or function; if it is a class, it
- applies recursively to all methods and classes defined in that class
- (but not to methods defined in its superclasses or subclasses).
-
- This mutates the function(s) or class(es) in place.
- """
- if isinstance(arg, type):
- arg_attrs = arg.__dict__.copy()
- for attr, val in arg.__dict__.items():
- if val in arg.__bases__ + (arg,):
- arg_attrs.pop(attr)
- for obj in arg_attrs.values():
- if isinstance(obj, types.FunctionType):
- obj.__no_type_check__ = True
- if isinstance(obj, type):
- no_type_check(obj)
- try:
- arg.__no_type_check__ = True
- except TypeError: # built-in classes
- pass
- return arg
-
-
-def no_type_check_decorator(decorator):
- """Decorator to give another decorator the @no_type_check effect.
-
- This wraps the decorator with something that wraps the decorated
- function in @no_type_check.
- """
-
- @functools.wraps(decorator)
- def wrapped_decorator(*args, **kwds):
- func = decorator(*args, **kwds)
- func = no_type_check(func)
- return func
-
- return wrapped_decorator
-
-
-def _overload_dummy(*args, **kwds):
- """Helper for @overload to raise when called."""
- raise NotImplementedError(
- "You should not call an overloaded function. "
- "A series of @overload-decorated functions "
- "outside a stub module should always be followed "
- "by an implementation that is not @overload-ed.")
-
-
-def overload(func):
- """Decorator for overloaded functions/methods.
-
- In a stub file, place two or more stub definitions for the same
- function in a row, each decorated with @overload. For example:
-
- @overload
- def utf8(value: None) -> None: ...
- @overload
- def utf8(value: bytes) -> bytes: ...
- @overload
- def utf8(value: str) -> bytes: ...
-
- In a non-stub file (i.e. a regular .py file), do the same but
- follow it with an implementation. The implementation should *not*
- be decorated with @overload. For example:
-
- @overload
- def utf8(value: None) -> None: ...
- @overload
- def utf8(value: bytes) -> bytes: ...
- @overload
- def utf8(value: str) -> bytes: ...
- def utf8(value):
- # implementation goes here
- """
- return _overload_dummy
-
-
-def final(f):
- """A decorator to indicate final methods and final classes.
-
- Use this decorator to indicate to type checkers that the decorated
- method cannot be overridden, and decorated class cannot be subclassed.
- For example:
-
- class Base:
- @final
- def done(self) -> None:
- ...
- class Sub(Base):
- def done(self) -> None: # Error reported by type checker
- ...
-
- @final
- class Leaf:
- ...
- class Other(Leaf): # Error reported by type checker
- ...
-
- There is no runtime checking of these properties.
- """
- return f
-
-
-# Some unconstrained type variables. These are used by the container types.
-# (These are not for export.)
-T = TypeVar('T') # Any type.
-KT = TypeVar('KT') # Key type.
-VT = TypeVar('VT') # Value type.
-T_co = TypeVar('T_co', covariant=True) # Any type covariant containers.
-V_co = TypeVar('V_co', covariant=True) # Any type covariant containers.
-VT_co = TypeVar('VT_co', covariant=True) # Value type covariant containers.
-T_contra = TypeVar('T_contra', contravariant=True) # Ditto contravariant.
-# Internal type variable used for Type[].
-CT_co = TypeVar('CT_co', covariant=True, bound=type)
-
-# A useful type variable with constraints. This represents string types.
-# (This one *is* for export!)
-AnyStr = TypeVar('AnyStr', bytes, str)
-
-# Various ABCs mimicking those in collections.abc.
-_alias = _SpecialGenericAlias
-
-Hashable = _alias(collections.abc.Hashable, 0) # Not generic.
-Awaitable = _alias(collections.abc.Awaitable, 1)
-Coroutine = _alias(collections.abc.Coroutine, 3)
-AsyncIterable = _alias(collections.abc.AsyncIterable, 1)
-AsyncIterator = _alias(collections.abc.AsyncIterator, 1)
-Iterable = _alias(collections.abc.Iterable, 1)
-Iterator = _alias(collections.abc.Iterator, 1)
-Reversible = _alias(collections.abc.Reversible, 1)
-Sized = _alias(collections.abc.Sized, 0) # Not generic.
-Container = _alias(collections.abc.Container, 1)
-Collection = _alias(collections.abc.Collection, 1)
-Callable = _CallableType(collections.abc.Callable, 2)
-Callable.__doc__ = \
- """Callable type; Callable[[int], str] is a function of (int) -> str.
-
- The subscription syntax must always be used with exactly two
- values: the argument list and the return type. The argument list
- must be a list of types or ellipsis; the return type must be a single type.
-
- There is no syntax to indicate optional or keyword arguments,
- such function types are rarely used as callback types.
- """
-AbstractSet = _alias(collections.abc.Set, 1, name='AbstractSet')
-MutableSet = _alias(collections.abc.MutableSet, 1)
-# NOTE: Mapping is only covariant in the value type.
-Mapping = _alias(collections.abc.Mapping, 2)
-MutableMapping = _alias(collections.abc.MutableMapping, 2)
-Sequence = _alias(collections.abc.Sequence, 1)
-MutableSequence = _alias(collections.abc.MutableSequence, 1)
-ByteString = _alias(collections.abc.ByteString, 0) # Not generic
-# Tuple accepts variable number of parameters.
-Tuple = _TupleType(tuple, -1, inst=False, name='Tuple')
-Tuple.__doc__ = \
- """Tuple type; Tuple[X, Y] is the cross-product type of X and Y.
-
- Example: Tuple[T1, T2] is a tuple of two elements corresponding
- to type variables T1 and T2. Tuple[int, float, str] is a tuple
- of an int, a float and a string.
-
- To specify a variable-length tuple of homogeneous type, use Tuple[T, ...].
- """
-List = _alias(list, 1, inst=False, name='List')
-Deque = _alias(collections.deque, 1, name='Deque')
-Set = _alias(set, 1, inst=False, name='Set')
-FrozenSet = _alias(frozenset, 1, inst=False, name='FrozenSet')
-MappingView = _alias(collections.abc.MappingView, 1)
-KeysView = _alias(collections.abc.KeysView, 1)
-ItemsView = _alias(collections.abc.ItemsView, 2)
-ValuesView = _alias(collections.abc.ValuesView, 1)
-ContextManager = _alias(contextlib.AbstractContextManager, 1, name='ContextManager')
-AsyncContextManager = _alias(contextlib.AbstractAsyncContextManager, 1, name='AsyncContextManager')
-Dict = _alias(dict, 2, inst=False, name='Dict')
-DefaultDict = _alias(collections.defaultdict, 2, name='DefaultDict')
-OrderedDict = _alias(collections.OrderedDict, 2)
-Counter = _alias(collections.Counter, 1)
-ChainMap = _alias(collections.ChainMap, 2)
-Generator = _alias(collections.abc.Generator, 3)
-AsyncGenerator = _alias(collections.abc.AsyncGenerator, 2)
-Type = _alias(type, 1, inst=False, name='Type')
-Type.__doc__ = \
- """A special construct usable to annotate class objects.
-
- For example, suppose we have the following classes::
-
- class User: ... # Abstract base for User classes
- class BasicUser(User): ...
- class ProUser(User): ...
- class TeamUser(User): ...
-
- And a function that takes a class argument that's a subclass of
- User and returns an instance of the corresponding class::
-
- U = TypeVar('U', bound=User)
- def new_user(user_class: Type[U]) -> U:
- user = user_class()
- # (Here we could write the user object to a database)
- return user
-
- joe = new_user(BasicUser)
-
- At this point the type checker knows that joe has type BasicUser.
- """
-
-
-@runtime_checkable
-class SupportsInt(Protocol):
- """An ABC with one abstract method __int__."""
- __slots__ = ()
-
- @abstractmethod
- def __int__(self) -> int:
- pass
-
-
-@runtime_checkable
-class SupportsFloat(Protocol):
- """An ABC with one abstract method __float__."""
- __slots__ = ()
-
- @abstractmethod
- def __float__(self) -> float:
- pass
-
-
-@runtime_checkable
-class SupportsComplex(Protocol):
- """An ABC with one abstract method __complex__."""
- __slots__ = ()
-
- @abstractmethod
- def __complex__(self) -> complex:
- pass
-
-
-@runtime_checkable
-class SupportsBytes(Protocol):
- """An ABC with one abstract method __bytes__."""
- __slots__ = ()
-
- @abstractmethod
- def __bytes__(self) -> bytes:
- pass
-
-
-@runtime_checkable
-class SupportsIndex(Protocol):
- """An ABC with one abstract method __index__."""
- __slots__ = ()
-
- @abstractmethod
- def __index__(self) -> int:
- pass
-
-
-@runtime_checkable
-class SupportsAbs(Protocol[T_co]):
- """An ABC with one abstract method __abs__ that is covariant in its return type."""
- __slots__ = ()
-
- @abstractmethod
- def __abs__(self) -> T_co:
- pass
-
-
-@runtime_checkable
-class SupportsRound(Protocol[T_co]):
- """An ABC with one abstract method __round__ that is covariant in its return type."""
- __slots__ = ()
-
- @abstractmethod
- def __round__(self, ndigits: int = 0) -> T_co:
- pass
-
-
-def _make_nmtuple(name, types, module, defaults=()):
- fields = [n for n, t in types]
- types = {n: _type_check(t, f"field {n} annotation must be a type")
- for n, t in types}
- nm_tpl = collections.namedtuple(name, fields,
- defaults=defaults, module=module)
- nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = types
- return nm_tpl
-
-
-# attributes prohibited to set in NamedTuple class syntax
-_prohibited = frozenset({'__new__', '__init__', '__slots__', '__getnewargs__',
- '_fields', '_field_defaults',
- '_make', '_replace', '_asdict', '_source'})
-
-_special = frozenset({'__module__', '__name__', '__annotations__'})
-
-
-class NamedTupleMeta(type):
-
- def __new__(cls, typename, bases, ns):
- assert bases[0] is _NamedTuple
- types = ns.get('__annotations__', {})
- default_names = []
- for field_name in types:
- if field_name in ns:
- default_names.append(field_name)
- elif default_names:
- raise TypeError(f"Non-default namedtuple field {field_name} "
- f"cannot follow default field"
- f"{'s' if len(default_names) > 1 else ''} "
- f"{', '.join(default_names)}")
- nm_tpl = _make_nmtuple(typename, types.items(),
- defaults=[ns[n] for n in default_names],
- module=ns['__module__'])
- # update from user namespace without overriding special namedtuple attributes
- for key in ns:
- if key in _prohibited:
- raise AttributeError("Cannot overwrite NamedTuple attribute " + key)
- elif key not in _special and key not in nm_tpl._fields:
- setattr(nm_tpl, key, ns[key])
- return nm_tpl
-
-
-def NamedTuple(typename, fields=None, /, **kwargs):
- """Typed version of namedtuple.
-
- Usage in Python versions >= 3.6::
-
- class Employee(NamedTuple):
- name: str
- id: int
-
- This is equivalent to::
-
- Employee = collections.namedtuple('Employee', ['name', 'id'])
-
- The resulting class has an extra __annotations__ attribute, giving a
- dict that maps field names to types. (The field names are also in
- the _fields attribute, which is part of the namedtuple API.)
- Alternative equivalent keyword syntax is also accepted::
-
- Employee = NamedTuple('Employee', name=str, id=int)
-
- In Python versions <= 3.5 use::
-
- Employee = NamedTuple('Employee', [('name', str), ('id', int)])
- """
- if fields is None:
- fields = kwargs.items()
- elif kwargs:
- raise TypeError("Either list of fields or keywords"
- " can be provided to NamedTuple, not both")
- try:
- module = sys._getframe(1).f_globals.get('__name__', '__main__')
- except (AttributeError, ValueError):
- module = None
- return _make_nmtuple(typename, fields, module=module)
-
-
-_NamedTuple = type.__new__(NamedTupleMeta, 'NamedTuple', (), {})
-
-
-def _namedtuple_mro_entries(bases):
- if len(bases) > 1:
- raise TypeError("Multiple inheritance with NamedTuple is not supported")
- assert bases[0] is NamedTuple
- return (_NamedTuple,)
-
-
-NamedTuple.__mro_entries__ = _namedtuple_mro_entries
-
-
-class _TypedDictMeta(type):
- def __new__(cls, name, bases, ns, total=True):
- """Create new typed dict class object.
-
- This method is called when TypedDict is subclassed,
- or when TypedDict is instantiated. This way
- TypedDict supports all three syntax forms described in its docstring.
- Subclasses and instances of TypedDict return actual dictionaries.
- """
- for base in bases:
- if type(base) is not _TypedDictMeta:
- raise TypeError('cannot inherit from both a TypedDict type '
- 'and a non-TypedDict base class')
- tp_dict = type.__new__(_TypedDictMeta, name, (dict,), ns)
-
- annotations = {}
- own_annotations = ns.get('__annotations__', {})
- own_annotation_keys = set(own_annotations.keys())
- msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
- own_annotations = {
- n: _type_check(tp, msg, module=tp_dict.__module__)
- for n, tp in own_annotations.items()
- }
- required_keys = set()
- optional_keys = set()
-
- for base in bases:
- annotations.update(base.__dict__.get('__annotations__', {}))
- required_keys.update(base.__dict__.get('__required_keys__', ()))
- optional_keys.update(base.__dict__.get('__optional_keys__', ()))
-
- annotations.update(own_annotations)
- if total:
- required_keys.update(own_annotation_keys)
- else:
- optional_keys.update(own_annotation_keys)
-
- tp_dict.__annotations__ = annotations
- tp_dict.__required_keys__ = frozenset(required_keys)
- tp_dict.__optional_keys__ = frozenset(optional_keys)
- if not hasattr(tp_dict, '__total__'):
- tp_dict.__total__ = total
- return tp_dict
-
- __call__ = dict # static method
-
- def __subclasscheck__(cls, other):
- # Typed dicts are only for static structural subtyping.
- raise TypeError('TypedDict does not support instance and class checks')
-
- __instancecheck__ = __subclasscheck__
-
-
-def TypedDict(typename, fields=None, /, *, total=True, **kwargs):
- """A simple typed namespace. At runtime it is equivalent to a plain dict.
-
- TypedDict creates a dictionary type that expects all of its
- instances to have a certain set of keys, where each key is
- associated with a value of a consistent type. This expectation
- is not checked at runtime but is only enforced by type checkers.
- Usage::
-
- class Point2D(TypedDict):
- x: int
- y: int
- label: str
-
- a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK
- b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check
-
- assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first')
-
- The type info can be accessed via the Point2D.__annotations__ dict, and
- the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets.
- TypedDict supports two additional equivalent forms::
-
- Point2D = TypedDict('Point2D', x=int, y=int, label=str)
- Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str})
-
- By default, all keys must be present in a TypedDict. It is possible
- to override this by specifying totality.
- Usage::
-
- class point2D(TypedDict, total=False):
- x: int
- y: int
-
- This means that a point2D TypedDict can have any of the keys omitted.A type
- checker is only expected to support a literal False or True as the value of
- the total argument. True is the default, and makes all items defined in the
- class body be required.
-
- The class syntax is only supported in Python 3.6+, while two other
- syntax forms work for Python 2.7 and 3.2+
- """
- if fields is None:
- fields = kwargs
- elif kwargs:
- raise TypeError("TypedDict takes either a dict or keyword arguments,"
- " but not both")
-
- ns = {'__annotations__': dict(fields)}
- try:
- # Setting correct module is necessary to make typed dict classes pickleable.
- ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__')
- except (AttributeError, ValueError):
- pass
-
- return _TypedDictMeta(typename, (), ns, total=total)
-
-
-_TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {})
-TypedDict.__mro_entries__ = lambda bases: (_TypedDict,)
-
-
-def NewType(name, tp):
- """NewType creates simple unique types with almost zero
- runtime overhead. NewType(name, tp) is considered a subtype of tp
- by static type checkers. At runtime, NewType(name, tp) returns
- a dummy function that simply returns its argument. Usage::
-
- UserId = NewType('UserId', int)
-
- def name_by_id(user_id: UserId) -> str:
- ...
-
- UserId('user') # Fails type check
-
- name_by_id(42) # Fails type check
- name_by_id(UserId(42)) # OK
-
- num = UserId(5) + 1 # type: int
- """
-
- def new_type(x):
- return x
-
- new_type.__name__ = name
- new_type.__supertype__ = tp
- return new_type
-
-
-# Python-version-specific alias (Python 2: unicode; Python 3: str)
-Text = str
-
-# Constant that's True when type checking, but False here.
-TYPE_CHECKING = False
-
-
-class IO(Generic[AnyStr]):
- """Generic base class for TextIO and BinaryIO.
-
- This is an abstract, generic version of the return of open().
-
- NOTE: This does not distinguish between the different possible
- classes (text vs. binary, read vs. write vs. read/write,
- append-only, unbuffered). The TextIO and BinaryIO subclasses
- below capture the distinctions between text vs. binary, which is
- pervasive in the interface; however we currently do not offer a
- way to track the other distinctions in the type system.
- """
-
- __slots__ = ()
-
- @property
- @abstractmethod
- def mode(self) -> str:
- pass
-
- @property
- @abstractmethod
- def name(self) -> str:
- pass
-
- @abstractmethod
- def close(self) -> None:
- pass
-
- @property
- @abstractmethod
- def closed(self) -> bool:
- pass
-
- @abstractmethod
- def fileno(self) -> int:
- pass
-
- @abstractmethod
- def flush(self) -> None:
- pass
-
- @abstractmethod
- def isatty(self) -> bool:
- pass
-
- @abstractmethod
- def read(self, n: int = -1) -> AnyStr:
- pass
-
- @abstractmethod
- def readable(self) -> bool:
- pass
-
- @abstractmethod
- def readline(self, limit: int = -1) -> AnyStr:
- pass
-
- @abstractmethod
- def readlines(self, hint: int = -1) -> List[AnyStr]:
- pass
-
- @abstractmethod
- def seek(self, offset: int, whence: int = 0) -> int:
- pass
-
- @abstractmethod
- def seekable(self) -> bool:
- pass
-
- @abstractmethod
- def tell(self) -> int:
- pass
-
- @abstractmethod
- def truncate(self, size: int = None) -> int:
- pass
-
- @abstractmethod
- def writable(self) -> bool:
- pass
-
- @abstractmethod
- def write(self, s: AnyStr) -> int:
- pass
-
- @abstractmethod
- def writelines(self, lines: List[AnyStr]) -> None:
- pass
-
- @abstractmethod
- def __enter__(self) -> 'IO[AnyStr]':
- pass
-
- @abstractmethod
- def __exit__(self, type, value, traceback) -> None:
- pass
-
-
-class BinaryIO(IO[bytes]):
- """Typed version of the return of open() in binary mode."""
-
- __slots__ = ()
-
- @abstractmethod
- def write(self, s: Union[bytes, bytearray]) -> int:
- pass
-
- @abstractmethod
- def __enter__(self) -> 'BinaryIO':
- pass
-
-
-class TextIO(IO[str]):
- """Typed version of the return of open() in text mode."""
-
- __slots__ = ()
-
- @property
- @abstractmethod
- def buffer(self) -> BinaryIO:
- pass
-
- @property
- @abstractmethod
- def encoding(self) -> str:
- pass
-
- @property
- @abstractmethod
- def errors(self) -> Optional[str]:
- pass
-
- @property
- @abstractmethod
- def line_buffering(self) -> bool:
- pass
-
- @property
- @abstractmethod
- def newlines(self) -> Any:
- pass
-
- @abstractmethod
- def __enter__(self) -> 'TextIO':
- pass
-
-
-class io:
- """Wrapper namespace for IO generic classes."""
-
- __all__ = ['IO', 'TextIO', 'BinaryIO']
- IO = IO
- TextIO = TextIO
- BinaryIO = BinaryIO
-
-
-io.__name__ = __name__ + '.io'
-sys.modules[io.__name__] = io
-
-Pattern = _alias(stdlib_re.Pattern, 1)
-Match = _alias(stdlib_re.Match, 1)
-
-
-class re:
- """Wrapper namespace for re type aliases."""
-
- __all__ = ['Pattern', 'Match']
- Pattern = Pattern
- Match = Match
-
-
-re.__name__ = __name__ + '.re'
-sys.modules[re.__name__] = re
diff --git a/brainpy/dnn/others.py b/brainpy/dnn/others.py
index 46f771a63..be4a8f846 100644
--- a/brainpy/dnn/others.py
+++ b/brainpy/dnn/others.py
@@ -1,9 +1,5 @@
-from brainpy._src.dnn.base import (
- Layer as Layer,
-)
-
from brainpy._src.dnn.dropout import (
Dropout as Dropout,
)
diff --git a/brainpy/neurons.py b/brainpy/neurons.py
index e045035a1..9f41ae089 100644
--- a/brainpy/neurons.py
+++ b/brainpy/neurons.py
@@ -27,3 +27,12 @@
FHN as FHN,
LIF_SFA_Bellec2020,
)
+from .dyn.others import (
+ InputGroup as InputGroup,
+ OutputGroup as OutputGroup,
+ SpikeTimeGroup as SpikeTimeGroup,
+ PoissonGroup as PoissonGroup,
+ Leaky as Leaky,
+ Integrator as Integrator,
+ OUProcess as OUProcess,
+)
diff --git a/tests/simulation/test_neu_HH.py b/tests/simulation/test_neu_HH.py
index 0990733a4..2e80cabb5 100644
--- a/tests/simulation/test_neu_HH.py
+++ b/tests/simulation/test_neu_HH.py
@@ -12,7 +12,7 @@ def __init__(self, size):
self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)
-class HHv2(bp.NeuDyn):
+class HHv2(bp.dyn.NeuDyn):
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, V_th=20., C=1.0):
super().__init__(size=size)
From 1830cffcb5728841c9466eda85ecedd145c975f3 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 16:08:31 +0800
Subject: [PATCH 018/326] updates and fixes
---
brainpy/__init__.py | 3 +-
brainpy/_add_deprecations.py | 16 ++++++-
brainpy/_src/dnn/activations.py | 56 ++++++++++++-------------
brainpy/_src/dnn/conv.py | 6 +--
brainpy/_src/dnn/dropout.py | 4 +-
brainpy/_src/dnn/function.py | 8 ++--
brainpy/_src/dnn/interoperation_flax.py | 4 +-
brainpy/_src/dnn/linear.py | 34 +++++++--------
brainpy/_src/dnn/normalization.py | 8 ++--
brainpy/_src/dnn/nvar.py | 4 +-
brainpy/_src/dnn/pooling.py | 8 ++--
brainpy/_src/dnn/reservoir.py | 4 +-
brainpy/_src/dnn/rnncells.py | 10 ++---
brainpy/_src/dyn/base.py | 8 ++--
brainpy/_src/dyn/projections/aligns.py | 10 ++---
brainpy/_src/dynsys.py | 34 ++++++++-------
brainpy/_src/layer.py | 8 ----
brainpy/_src/losses/base.py | 4 +-
brainpy/_src/mixin.py | 2 +-
brainpy/_src/tests/test_mixin.py | 2 +-
brainpy/dyn/base.py | 2 +-
21 files changed, 123 insertions(+), 112 deletions(-)
delete mode 100644 brainpy/_src/layer.py
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index efb4af83d..89e407e5e 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -59,8 +59,9 @@
DynSysGroup as DynSysGroup, # collectors
Sequential as Sequential,
Network as Network,
- Dynamics as Dynamics, # category
+ Dynamic as Dynamic, # category
Projection as Projection,
+ AnnLayer as Layer,
)
DynamicalSystemNS = DynamicalSystem
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
index b7a477ae3..77208381c 100644
--- a/brainpy/_add_deprecations.py
+++ b/brainpy/_add_deprecations.py
@@ -1,6 +1,6 @@
from ._src import checking, train, integrators
-from . import tools, math, integrators, dyn, neurons, synapses
+from . import tools, math, integrators, dyn, dnn, neurons, synapses, layers
from .integrators import ode, fde, sde
from brainpy._src.integrators.base import Integrator
from brainpy._src.integrators.runner import IntegratorRunner
@@ -8,7 +8,7 @@
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.integrators.sde.generic import sdeint
from brainpy._src.integrators.fde.generic import fdeint
-from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network)
+from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network, AnnLayer)
from brainpy._src.dyn.base import NeuDyn, IonChaDyn
from brainpy._src.runners import DSRunner
from brainpy._src.deprecations import deprecation_getattr2
@@ -102,3 +102,15 @@
dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations)
+dnn.__deprecations = {
+ 'Layer': ('brainpy.dnn.Layer', 'brainpy.AnnLayer', AnnLayer),
+}
+dnn.__getattr__ = deprecation_getattr2('brainpy.dnn', dnn.__deprecations)
+
+
+layers.__deprecations = {
+ 'Layer': ('brainpy.layers.Layer', 'brainpy.AnnLayer', AnnLayer),
+}
+layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations)
+
+
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index a1bef95e0..d079e4421 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -1,8 +1,8 @@
from typing import Optional
from brainpy import math as bm
+from brainpy._src.dynsys import AnnLayer
from brainpy.types import ArrayType
-from brainpy._src.layer import Layer
__all__ = [
'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh',
@@ -21,7 +21,7 @@ def _inplace(inp, val, inplace):
return val
-class Threshold(Layer):
+class Threshold(AnnLayer):
r"""Thresholds each element of the input Tensor.
Threshold is defined as:
@@ -73,7 +73,7 @@ def extra_repr(self):
)
-class ReLU(Layer):
+class ReLU(AnnLayer):
r"""Applies the rectified linear unit function element-wise:
:math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
@@ -118,7 +118,7 @@ def extra_repr(self) -> str:
return inplace_str
-class RReLU(Layer):
+class RReLU(AnnLayer):
r"""Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper:
@@ -184,7 +184,7 @@ def extra_repr(self):
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
-class Hardtanh(Layer):
+class Hardtanh(AnnLayer):
r"""Applies the HardTanh function element-wise.
HardTanh is defined as:
@@ -275,7 +275,7 @@ def extra_repr(self) -> str:
return inplace_str
-class Sigmoid(Layer):
+class Sigmoid(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -299,7 +299,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.sigmoid(input)
-class Hardsigmoid(Layer):
+class Hardsigmoid(AnnLayer):
r"""Applies the Hardsigmoid function element-wise.
Hardsigmoid is defined as:
@@ -339,7 +339,7 @@ def update(self, input: ArrayType) -> ArrayType:
return _inplace(input, x, self.inplace)
-class Tanh(Layer):
+class Tanh(AnnLayer):
r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
Tanh is defined as:
@@ -364,7 +364,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.tanh(input)
-class SiLU(Layer):
+class SiLU(AnnLayer):
r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
The SiLU function is also known as the swish function.
@@ -406,7 +406,7 @@ def extra_repr(self) -> str:
return inplace_str
-class Mish(Layer):
+class Mish(AnnLayer):
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
@@ -443,7 +443,7 @@ def extra_repr(self) -> str:
return inplace_str
-class Hardswish(Layer):
+class Hardswish(AnnLayer):
r"""Applies the Hardswish function, element-wise, as described in the paper:
`Searching for MobileNetV3 `_.
@@ -483,7 +483,7 @@ def update(self, input: ArrayType) -> ArrayType:
return _inplace(input, bm.hard_swish(input), self.inplace)
-class ELU(Layer):
+class ELU(AnnLayer):
r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
Units (ELUs) `__.
@@ -529,7 +529,7 @@ def extra_repr(self) -> str:
return 'alpha={}{}'.format(self.alpha, inplace_str)
-class CELU(Layer):
+class CELU(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -573,7 +573,7 @@ def extra_repr(self) -> str:
return 'alpha={}{}'.format(self.alpha, inplace_str)
-class SELU(Layer):
+class SELU(AnnLayer):
r"""Applied element-wise, as:
.. math::
@@ -616,7 +616,7 @@ def extra_repr(self) -> str:
return inplace_str
-class GLU(Layer):
+class GLU(AnnLayer):
r"""Applies the gated linear unit function
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
of the input matrices and :math:`b` is the second half.
@@ -651,7 +651,7 @@ def extra_repr(self) -> str:
return 'dim={}'.format(self.dim)
-class GELU(Layer):
+class GELU(AnnLayer):
r"""Applies the Gaussian Error Linear Units function:
.. math:: \text{GELU}(x) = x * \Phi(x)
@@ -692,7 +692,7 @@ def extra_repr(self) -> str:
return 'approximate={}'.format(repr(self.approximate))
-class Hardshrink(Layer):
+class Hardshrink(AnnLayer):
r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
Hardshrink is defined as:
@@ -734,7 +734,7 @@ def extra_repr(self) -> str:
return '{}'.format(self.lambd)
-class LeakyReLU(Layer):
+class LeakyReLU(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -785,7 +785,7 @@ def extra_repr(self) -> str:
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
-class LogSigmoid(Layer):
+class LogSigmoid(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -808,7 +808,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.log_sigmoid(input)
-class Softplus(Layer):
+class Softplus(AnnLayer):
r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
\log(1 + \exp(\beta * x))` element-wise.
@@ -850,7 +850,7 @@ def extra_repr(self) -> str:
return 'beta={}, threshold={}'.format(self.beta, self.threshold)
-class Softshrink(Layer):
+class Softshrink(AnnLayer):
r"""Applies the soft shrinkage function elementwise:
.. math::
@@ -890,7 +890,7 @@ def extra_repr(self) -> str:
return str(self.lambd)
-class PReLU(Layer):
+class PReLU(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -954,7 +954,7 @@ def extra_repr(self) -> str:
return 'num_parameters={}'.format(self.num_parameters)
-class Softsign(Layer):
+class Softsign(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -977,7 +977,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.soft_sign(input)
-class Tanhshrink(Layer):
+class Tanhshrink(AnnLayer):
r"""Applies the element-wise function:
.. math::
@@ -1000,7 +1000,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.tanh_shrink(input)
-class Softmin(Layer):
+class Softmin(AnnLayer):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range `[0, 1]` and sum to 1.
@@ -1045,7 +1045,7 @@ def extra_repr(self):
return 'dim={dim}'.format(dim=self.dim)
-class Softmax(Layer):
+class Softmax(AnnLayer):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range [0,1] and sum to 1.
@@ -1099,7 +1099,7 @@ def extra_repr(self) -> str:
return 'dim={dim}'.format(dim=self.dim)
-class Softmax2d(Layer):
+class Softmax2d(AnnLayer):
r"""Applies SoftMax over features to each spatial location.
When given an image of ``Channels x Height x Width``, it will
@@ -1128,7 +1128,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.softmax(input, -3)
-class LogSoftmax(Layer):
+class LogSoftmax(AnnLayer):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor. The LogSoftmax formulation can be simplified as:
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index f5e4a1e60..daf85ad74 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -7,7 +7,7 @@
from brainpy import math as bm, tools
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'Conv1d', 'Conv2d', 'Conv3d',
@@ -36,7 +36,7 @@ def to_dimension_numbers(num_spatial_dims: int,
out_spec=image_dn)
-class _GeneralConv(Layer):
+class _GeneralConv(AnnLayer):
"""Apply a convolution to the inputs.
Parameters
@@ -462,7 +462,7 @@ def _check_input_dim(self, x):
Conv3D = Conv3d
-class _GeneralConvTranspose(Layer):
+class _GeneralConvTranspose(AnnLayer):
supported_modes = (bm.TrainingMode, bm.BatchingMode)
def __init__(
diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py
index 184a46aa5..c5583b67f 100644
--- a/brainpy/_src/dnn/dropout.py
+++ b/brainpy/_src/dnn/dropout.py
@@ -4,14 +4,14 @@
from brainpy._src.context import share
from brainpy import math as bm, check
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'Dropout'
]
-class Dropout(Layer):
+class Dropout(AnnLayer):
"""A layer that stochastically ignores a subset of inputs each training step.
In training, to compensate for the fraction of input values dropped (`rate`),
diff --git a/brainpy/_src/dnn/function.py b/brainpy/_src/dnn/function.py
index 7d12246b4..0223a387a 100644
--- a/brainpy/_src/dnn/function.py
+++ b/brainpy/_src/dnn/function.py
@@ -5,7 +5,7 @@
import brainpy.math as bm
from brainpy import check
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'Activation',
@@ -14,7 +14,7 @@
]
-class Activation(Layer):
+class Activation(AnnLayer):
r"""Applies an activation function to the inputs
Parameters:
@@ -43,7 +43,7 @@ def update(self, *args, **kwargs):
return self.activate_fun(*args, **kwargs, **self.kwargs)
-class Flatten(Layer):
+class Flatten(AnnLayer):
r"""Flattens a contiguous range of dims into 2D or 1D.
Parameters:
@@ -69,7 +69,7 @@ def update(self, x):
return x.flatten()
-class FunAsLayer(Layer):
+class FunAsLayer(AnnLayer):
def __init__(
self,
fun: Callable,
diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py
index ce98964fc..5765df8fa 100644
--- a/brainpy/_src/dnn/interoperation_flax.py
+++ b/brainpy/_src/dnn/interoperation_flax.py
@@ -7,7 +7,7 @@
from brainpy import math as bm
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
try:
import flax # noqa
@@ -35,7 +35,7 @@ def _is_bp(a):
return isinstance(a, bm.Array)
-class FromFlax(Layer):
+class FromFlax(AnnLayer):
"""
Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`.
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index b4f638fca..a34f148c2 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -14,7 +14,7 @@
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'Dense', 'Linear',
@@ -28,7 +28,7 @@
]
-class Dense(Layer):
+class Dense(AnnLayer):
r"""A linear transformation applied over the last dimension of the input.
Mathematically, this node can be defined as:
@@ -207,7 +207,7 @@ def offline_fit(self,
Linear = Dense
-class Identity(Layer):
+class Identity(AnnLayer):
r"""A placeholder identity operator that is argument-insensitive.
"""
@@ -218,7 +218,7 @@ def update(self, x):
return x
-class AllToAll(Layer):
+class AllToAll(AnnLayer):
"""Synaptic matrix multiplication with All2All connections.
Args:
@@ -281,7 +281,7 @@ def update(self, pre_val):
return post_val
-class OneToOne(Layer):
+class OneToOne(AnnLayer):
"""Synaptic matrix multiplication with One2One connection.
Args:
@@ -315,7 +315,7 @@ def update(self, pre_val):
return pre_val * self.weight
-class MaskedLinear(Layer):
+class MaskedLinear(AnnLayer):
r"""Synaptic matrix multiplication with masked dense computation.
It performs the computation of:
@@ -366,7 +366,7 @@ def update(self, x):
return x @ (self.weight * self.mask)
-class CSRLinear(Layer):
+class CSRLinear(AnnLayer):
r"""Synaptic matrix multiplication with CSR sparse computation.
It performs the computation of:
@@ -435,7 +435,7 @@ def _batch_csrmv(self, x):
method=self.method)
-class CSCLinear(Layer):
+class CSCLinear(AnnLayer):
r"""Synaptic matrix multiplication with CSC sparse computation.
It performs the computation of:
@@ -470,7 +470,7 @@ def __init__(
self.sharding = sharding
-class EventCSRLinear(Layer):
+class EventCSRLinear(AnnLayer):
r"""Synaptic matrix multiplication with event CSR sparse computation.
It performs the computation of:
@@ -535,7 +535,7 @@ def _batch_csrmv(self, x):
transpose=self.transpose)
-class BcsrMM(Layer):
+class BcsrMM(AnnLayer):
r"""Synaptic matrix multiplication with BCSR sparse computation.
It performs the computation of:
@@ -570,7 +570,7 @@ def __init__(
self.sharding = sharding
-class BcscMM(Layer):
+class BcscMM(AnnLayer):
r"""Synaptic matrix multiplication with BCSC sparse computation.
It performs the computation of:
@@ -605,7 +605,7 @@ def __init__(
self.sharding = sharding
-class JitFPHomoLinear(Layer):
+class JitFPHomoLinear(AnnLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -684,7 +684,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class JitFPUniformLinear(Layer):
+class JitFPUniformLinear(AnnLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -764,7 +764,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class JitFPNormalLinear(Layer):
+class JitFPNormalLinear(AnnLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -844,7 +844,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class EventJitFPHomoLinear(Layer):
+class EventJitFPHomoLinear(AnnLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -923,7 +923,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class EventJitFPUniformLinear(Layer):
+class EventJitFPUniformLinear(AnnLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1003,7 +1003,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class EventJitFPNormalLinear(Layer):
+class EventJitFPNormalLinear(AnnLayer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
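
All of the linear layers above keep their public behaviour; only their base class changes from Layer to AnnLayer. A minimal usage sketch of Dense under that assumption (the positional num_in/num_out arguments come from the existing API, not from this patch):

    import brainpy as bp
    import brainpy.math as bm

    # Dense still maps the last input dimension (8) to the output dimension (4);
    # only its base class changed in this patch.
    layer = bp.dnn.Dense(8, 4)
    y = layer(bm.random.randn(2, 8))   # y.shape == (2, 4)
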
diff --git a/brainpy/_src/dnn/normalization.py b/brainpy/_src/dnn/normalization.py
index e99e162c3..dad6dd841 100644
--- a/brainpy/_src/dnn/normalization.py
+++ b/brainpy/_src/dnn/normalization.py
@@ -8,7 +8,7 @@
from brainpy import math as bm, check
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.types import ArrayType
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'BatchNorm1d',
@@ -32,7 +32,7 @@ def _square(x):
return lax.square(x)
-class BatchNorm(Layer):
+class BatchNorm(AnnLayer):
r"""Batch Normalization layer [1]_.
This layer aims to reduce the internal covariant shift of data. It
@@ -407,7 +407,7 @@ def _check_input_dim(self, x):
assert x.shape[-1] == self.num_features
-class LayerNorm(Layer):
+class LayerNorm(AnnLayer):
r"""Layer normalization (https://arxiv.org/abs/1607.06450).
.. math::
@@ -504,7 +504,7 @@ def update(self, x):
return out
-class GroupNorm(Layer):
+class GroupNorm(AnnLayer):
r"""Group normalization layer.
.. math::
diff --git a/brainpy/_src/dnn/nvar.py b/brainpy/_src/dnn/nvar.py
index da1f6ed48..87029a45b 100644
--- a/brainpy/_src/dnn/nvar.py
+++ b/brainpy/_src/dnn/nvar.py
@@ -8,7 +8,7 @@
import brainpy.math as bm
from brainpy import check
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'NVAR'
@@ -34,7 +34,7 @@ def _comb(N, k):
return 0
-class NVAR(Layer):
+class NVAR(AnnLayer):
"""Nonlinear vector auto-regression (NVAR) node.
This class has the following features:
diff --git a/brainpy/_src/dnn/pooling.py b/brainpy/_src/dnn/pooling.py
index 3bb38ff3b..148e8537e 100644
--- a/brainpy/_src/dnn/pooling.py
+++ b/brainpy/_src/dnn/pooling.py
@@ -7,7 +7,7 @@
import numpy as np
from brainpy import math as bm, check
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'MaxPool',
@@ -28,7 +28,7 @@
]
-class Pool(Layer):
+class Pool(AnnLayer):
"""Pooling functions are implemented using the ReduceWindow XLA op.
Parameters
@@ -285,7 +285,7 @@ def update(self, x):
return pooled / window_counts
-class _MaxPoolNd(Layer):
+class _MaxPoolNd(AnnLayer):
def __init__(
self,
init_value,
@@ -717,7 +717,7 @@ def _generate_vmap(fun: Callable, map_axes: List[int]):
return fun
-class AdaptivePool(Layer):
+class AdaptivePool(AnnLayer):
"""General N dimensional adaptive down-sampling to a target shape.
Parameters
diff --git a/brainpy/_src/dnn/reservoir.py b/brainpy/_src/dnn/reservoir.py
index c5ea3cb5a..e21605ac2 100644
--- a/brainpy/_src/dnn/reservoir.py
+++ b/brainpy/_src/dnn/reservoir.py
@@ -9,14 +9,14 @@
from brainpy import check
from brainpy.tools import to_size
from brainpy.types import ArrayType
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'Reservoir',
]
-class Reservoir(Layer):
+class Reservoir(AnnLayer):
r"""Reservoir node, a pool of leaky-integrator neurons
with random recurrent connections [1]_.
diff --git a/brainpy/_src/dnn/rnncells.py b/brainpy/_src/dnn/rnncells.py
index 2df1b4a76..0038e2d29 100644
--- a/brainpy/_src/dnn/rnncells.py
+++ b/brainpy/_src/dnn/rnncells.py
@@ -7,7 +7,7 @@
import brainpy.math as bm
from brainpy.math import activations
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
from brainpy.check import (is_integer,
is_initializer)
from brainpy.initialize import (XavierNormal,
@@ -27,7 +27,7 @@
]
-class RNNCell(Layer):
+class RNNCell(AnnLayer):
r"""Basic fully-connected RNN core.
Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the
@@ -125,7 +125,7 @@ def update(self, x):
return self.state.value
-class GRUCell(Layer):
+class GRUCell(AnnLayer):
r"""Gated Recurrent Unit.
The implementation is based on (Chung, et al., 2014) [1]_ with biases.
@@ -247,7 +247,7 @@ def update(self, x):
return self.state.value
-class LSTMCell(Layer):
+class LSTMCell(AnnLayer):
r"""Long short-term memory (LSTM) RNN core.
The implementation is based on (zaremba, et al., 2014) [1]_. Given
@@ -442,7 +442,7 @@ def __init__(self, *args, **kwargs):
super(LSTM, self).__init__(*args, **kwargs)
-class _ConvNDLSTMCell(Layer):
+class _ConvNDLSTMCell(AnnLayer):
r"""``num_spatial_dims``-D convolutional LSTM.
The implementation is based on :cite:`xingjian2015convolutional`.
diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
index c37504d47..e318eee4b 100644
--- a/brainpy/_src/dyn/base.py
+++ b/brainpy/_src/dyn/base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-from brainpy._src.dynsys import Dynamics
+from brainpy._src.dynsys import Dynamic
from brainpy._src.mixin import AutoDelaySupp, ParamDesc
__all__ = [
@@ -8,17 +8,17 @@
]
-class NeuDyn(Dynamics, AutoDelaySupp):
+class NeuDyn(Dynamic, AutoDelaySupp):
"""Neuronal Dynamics."""
pass
-class SynDyn(Dynamics, AutoDelaySupp, ParamDesc):
+class SynDyn(Dynamic, AutoDelaySupp, ParamDesc):
"""Synaptic Dynamics."""
pass
-class IonChaDyn(Dynamics):
+class IonChaDyn(Dynamic):
"""Ion Channel Dynamics."""
pass
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index 7d0f7395b..925d7dd22 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -2,7 +2,7 @@
from brainpy import math as bm
from brainpy._src.delay import Delay, VariableDelay, DataDelay
-from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamics
+from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic
from brainpy._src.mixin import JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
__all__ = [
@@ -81,7 +81,7 @@ def __init__(
delay: Union[None, int, float],
comm: Callable,
out: JointType[DynamicalSystem, BindCondData],
- post: Dynamics,
+ post: Dynamic,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
@@ -92,7 +92,7 @@ def __init__(
assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
assert isinstance(comm, Callable)
assert isinstance(out, JointType[DynamicalSystem, BindCondData])
- assert isinstance(post, Dynamics)
+ assert isinstance(post, Dynamic)
self.pre = pre
self.post = post
self.comm = comm
@@ -140,7 +140,7 @@ def __init__(
comm: Callable,
syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
- post: Dynamics,
+ post: Dynamic,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
@@ -151,7 +151,7 @@ def __init__(
assert isinstance(comm, Callable)
assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
assert isinstance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
- assert isinstance(post, Dynamics)
+ assert isinstance(post, Dynamic)
self.pre = pre
self.post = post
self.comm = comm
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 131ad925a..8a096ddf9 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
+import collections
import gc
import inspect
-from typing import Union, Dict, Callable, Sequence, Optional, Tuple, Any
-import collections
+from typing import Union, Dict, Callable, Sequence, Optional, Any
-import jax
import numpy as np
from brainpy import tools, math as bm
@@ -24,7 +23,7 @@
'DynSysGroup', 'Network', 'Sequential',
# category
- 'Dynamics', 'Projection',
+ 'Dynamic', 'Projection', 'AnnLayer',
]
SLICE_VARS = 'slice_vars'
@@ -356,18 +355,18 @@ def update(self, *args, **kwargs):
node()
# update nodes of dynamics
- for node in nodes.subset(Dynamics).values():
+ for node in nodes.subset(Dynamic).values():
node()
# update nodes with other types, including delays, ...
- for node in nodes.not_subset(Dynamics).not_subset(Projection).values():
+ for node in nodes.not_subset(Dynamic).not_subset(Projection).values():
node()
def reset_state(self, batch_size=None):
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
# reset dynamics
- for node in nodes.subset(Dynamics).values():
+ for node in nodes.subset(Dynamic).values():
node.reset_state(batch_size)
# reset projections
@@ -375,7 +374,7 @@ def reset_state(self, batch_size=None):
node.reset_state(batch_size)
# reset other types of nodes, including delays, ...
- for node in nodes.not_subset(Dynamics).not_subset(Projection).values():
+ for node in nodes.not_subset(Dynamic).not_subset(Projection).values():
node.reset_state(batch_size)
@@ -513,7 +512,7 @@ def reset_state(self, *args, **kwargs):
pass
-class Dynamics(DynamicalSystem):
+class Dynamic(DynamicalSystem):
"""Base class to model dynamics.
There are several essential attributes:
@@ -627,7 +626,14 @@ def __getitem__(self, item):
return DynView(target=self, index=item)
-class DynView(Dynamics):
+class AnnLayer(DynamicalSystem):
+ """Base class for a layer of artificial neural network."""
+
+ def reset_state(self, *args, **kwargs):
+ pass
+
+
+class DynView(Dynamic):
"""DSView, an object used to get a view of a dynamical system instance.
It can get a subset view of variables in a dynamical system instance.
@@ -642,13 +648,13 @@ class DynView(Dynamics):
def __init__(
self,
- target: Dynamics,
+ target: Dynamic,
index: Union[slice, Sequence, ArrayType],
name: Optional[str] = None,
):
# check target
- if not isinstance(target, Dynamics):
- raise TypeError(f'Should be instance of {Dynamics.__name__}, but we got {type(target)}.')
+ if not isinstance(target, Dynamic):
+ raise TypeError(f'Should be instance of {Dynamic.__name__}, but we got {type(target)}.')
self.target = target # the target object to slice
# check slicing
@@ -687,7 +693,7 @@ def __init__(
# sub-nodes
nodes = target.nodes(method='relative', level=1, include_self=False).subset(DynamicalSystem)
for k, node in nodes.items():
- if isinstance(node, Dynamics):
+ if isinstance(node, Dynamic):
node = DynView(node, self.index)
else:
node = DynView(node, self.index)
diff --git a/brainpy/_src/layer.py b/brainpy/_src/layer.py
deleted file mode 100644
index af0b4e2fc..000000000
--- a/brainpy/_src/layer.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from brainpy._src.dynsys import DynamicalSystem
-
-
-class Layer(DynamicalSystem):
- """Base class for a layer of artificial neural network."""
-
- def reset_state(self, *args, **kwargs):
- pass
diff --git a/brainpy/_src/losses/base.py b/brainpy/_src/losses/base.py
index e8f6434fa..e1cfecf28 100644
--- a/brainpy/_src/losses/base.py
+++ b/brainpy/_src/losses/base.py
@@ -1,6 +1,6 @@
from typing import Optional
-from brainpy._src.layer import Layer
+from brainpy._src.dynsys import AnnLayer
__all__ = [
'Loss',
@@ -8,7 +8,7 @@
]
-class Loss(Layer):
+class Loss(AnnLayer):
reduction: str
def __init__(self, reduction: str = 'mean') -> None:
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 143c8884f..547529076 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -10,7 +10,6 @@
from brainpy import math as bm, tools
from brainpy._src.initialize import parameter
-
from brainpy.types import ArrayType
if sys.version_info.minor > 8:
@@ -23,6 +22,7 @@
__all__ = [
'MixIn',
'ParamDesc',
+ 'ParamDescInit',
'AlignPost',
'AutoDelaySupp',
'NoSH',
diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py
index 1544a1f33..1352d47b7 100644
--- a/brainpy/_src/tests/test_mixin.py
+++ b/brainpy/_src/tests/test_mixin.py
@@ -18,7 +18,7 @@ def test2(self):
class TestJointType(unittest.TestCase):
def test1(self):
T = bp.mixin.JointType[bp.DynamicalSystem]
- self.assertTrue(isinstance(bp.dnn.Layer(), T))
+ self.assertTrue(isinstance(bp.AnnLayer(), T))
T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc]
self.assertTrue(isinstance(bp.dyn.Expon(1), T))
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
index 5d94717c4..8bcc487da 100644
--- a/brainpy/dyn/base.py
+++ b/brainpy/dyn/base.py
@@ -1,6 +1,6 @@
from brainpy._src.dyn.base import (
- Dynamics,
+ Dynamic,
NeuDyn,
SynDyn,
IonChaDyn,
From d3fd10f14bcec9dda58888ea38b21658add39db1 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 16:35:56 +0800
Subject: [PATCH 019/326] update docs
---
.gitignore | 1 -
README.md | 5 +-
brainpy/__init__.py | 2 +-
brainpy/_src/math/object_transform/jit.py | 10 ----
brainpy/dnn/others.py | 5 ++
brainpy/dyn/base.py | 1 -
docs/apis/channels.rst | 10 ++++
docs/apis/layers.rst | 10 ++++
docs/apis/neurons.rst | 73 +++++++++++++++++++++++
docs/apis/rates.rst | 16 +++++
docs/apis/synapses.rst | 52 ++++++++++++++++
docs/apis/synouts.rst | 28 +++++++++
docs/apis/synplast.rst | 20 +++++++
docs/auto_generater.py | 54 ++++++++---------
docs/conf.py | 11 ++--
docs/index.rst | 17 ++++--
16 files changed, 261 insertions(+), 54 deletions(-)
create mode 100644 docs/apis/channels.rst
create mode 100644 docs/apis/layers.rst
create mode 100644 docs/apis/neurons.rst
create mode 100644 docs/apis/rates.rst
create mode 100644 docs/apis/synapses.rst
create mode 100644 docs/apis/synouts.rst
create mode 100644 docs/apis/synplast.rst
diff --git a/.gitignore b/.gitignore
index ab1abb6ae..dec4fa91d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -225,4 +225,3 @@ cython_debug/
/docs/tutorial_advanced/data/
/my_tests/
/examples/dynamics_simulation/Joglekar_2018_data/
-/docs/apis/
diff --git a/README.md b/README.md
index fec353b71..a037ffbc4 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,11 @@
-
+
+
-
-
BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 89e407e5e..68e72c21c 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -61,7 +61,7 @@
Network as Network,
Dynamic as Dynamic, # category
Projection as Projection,
- AnnLayer as Layer,
+ AnnLayer as AnnLayer,
)
DynamicalSystemNS = DynamicalSystem
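
With the export renamed, the ANN base class is reachable as bp.AnnLayer (as the updated test_mixin.py already assumes). A minimal sketch of subclassing it, assuming only the public DynamicalSystem interface:

    import brainpy as bp
    import brainpy.math as bm

    class Scale(bp.AnnLayer):
        # Toy layer that multiplies its input by a trainable factor.
        def __init__(self, init=2.0):
            super().__init__()
            self.factor = bm.TrainVar(bm.asarray(init))

        def update(self, x):
            return x * self.factor

    layer = Scale()
    y = layer(bm.random.randn(4))   # AnnLayer.reset_state is a no-op, so no state to reset
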
diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index de163c54f..42111dba0 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -77,7 +77,6 @@ def __init__(
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Iterable[int]] = (),
- device: Optional[Any] = None,
inline: bool = False,
keep_unused: bool = False,
abstracted_axes: Optional[Any] = None,
@@ -106,7 +105,6 @@ def __init__(
self._static_argnums = _seq_of_int(static_argnums)
self._static_argnames = _seq_of_str(static_argnames)
self._donate_argnums = donate_argnums
- self._device = device
self._inline = inline
self._keep_unused = keep_unused
self._abstracted_axes = abstracted_axes
@@ -151,7 +149,6 @@ def __call__(self, *args, **kwargs):
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
static_argnames=self._static_argnames,
donate_argnums=self._donate_argnums,
- device=self._device,
inline=self._inline,
keep_unused=self._keep_unused,
abstracted_axes=self._abstracted_axes,
@@ -231,10 +228,8 @@ def jit(
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int]] = (),
- device: Optional[Any] = None,
inline: bool = False,
keep_unused: bool = False,
- backend: Optional[str] = None,
abstracted_axes: Optional[Any] = None,
# deprecated
@@ -311,7 +306,6 @@ def jit(
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
- device=device,
inline=inline,
keep_unused=keep_unused,
abstracted_axes=abstracted_axes,
@@ -323,7 +317,6 @@ def jit(
static_argnums=static_argnums,
static_argnames=static_argnames,
donate_argnums=donate_argnums,
- device=device,
inline=inline,
keep_unused=keep_unused,
abstracted_axes=abstracted_axes,
@@ -337,7 +330,6 @@ def cls_jit(
func: Callable = None,
static_argnums: Union[int, Iterable[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
- device: Optional[Any] = None,
inline: bool = False,
keep_unused: bool = False,
abstracted_axes: Optional[Any] = None,
@@ -381,7 +373,6 @@ def cls_jit(
return lambda f: _make_jit_fun(fun=f,
static_argnums=static_argnums,
static_argnames=static_argnames,
- device=device,
inline=inline,
keep_unused=keep_unused,
abstracted_axes=abstracted_axes,
@@ -390,7 +381,6 @@ def cls_jit(
return _make_jit_fun(fun=func,
static_argnums=static_argnums,
static_argnames=static_argnames,
- device=device,
inline=inline,
keep_unused=keep_unused,
abstracted_axes=abstracted_axes,
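
Because the device and backend keywords are removed here, callers should rely on the remaining options only. A hedged sketch of the post-patch call (function-wrapping form, matching the signature above):

    import brainpy.math as bm

    def scaled_sum(x, factor):
        return (x * factor).sum()

    # `device=` is no longer accepted; static_argnums, donate_argnums,
    # inline, keep_unused and abstracted_axes remain available.
    jitted = bm.jit(scaled_sum, static_argnums=(1,))
    out = jitted(bm.arange(10.), 2.0)
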
diff --git a/brainpy/dnn/others.py b/brainpy/dnn/others.py
index be4a8f846..958c155a1 100644
--- a/brainpy/dnn/others.py
+++ b/brainpy/dnn/others.py
@@ -3,3 +3,8 @@
from brainpy._src.dnn.dropout import (
Dropout as Dropout,
)
+from brainpy._src.dnn.function import (
+ Activation,
+ Flatten,
+ FunAsLayer,
+)
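
These re-exports make the function-style layers reachable from brainpy.dnn as well. A small sketch, assuming Activation takes the wrapped callable as its first argument (consistent with the Activation.update wrapper shown earlier in this series):

    import brainpy as bp
    import brainpy.math as bm

    act = bp.dnn.Activation(bm.relu)   # constructor signature assumed
    flat = bp.dnn.Flatten()
    y = flat(act(bm.random.randn(4, 3, 5)))
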
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
index 8bcc487da..0553d2658 100644
--- a/brainpy/dyn/base.py
+++ b/brainpy/dyn/base.py
@@ -1,6 +1,5 @@
from brainpy._src.dyn.base import (
- Dynamic,
NeuDyn,
SynDyn,
IonChaDyn,
diff --git a/docs/apis/channels.rst b/docs/apis/channels.rst
new file mode 100644
index 000000000..cad21004d
--- /dev/null
+++ b/docs/apis/channels.rst
@@ -0,0 +1,10 @@
+``brainpy.channels`` module
+===========================
+
+.. currentmodule:: brainpy.channels
+.. automodule:: brainpy.channels
+
+.. contents::
+ :local:
+ :depth: 1
+
diff --git a/docs/apis/layers.rst b/docs/apis/layers.rst
new file mode 100644
index 000000000..46fcdd905
--- /dev/null
+++ b/docs/apis/layers.rst
@@ -0,0 +1,10 @@
+``brainpy.layers`` module
+===========================
+
+.. currentmodule:: brainpy.layers
+.. automodule:: brainpy.layers
+
+.. contents::
+ :local:
+ :depth: 1
+
diff --git a/docs/apis/neurons.rst b/docs/apis/neurons.rst
new file mode 100644
index 000000000..5c53a4a4f
--- /dev/null
+++ b/docs/apis/neurons.rst
@@ -0,0 +1,73 @@
+``brainpy.neurons`` module
+==========================
+
+.. currentmodule:: brainpy.neurons
+.. automodule:: brainpy.neurons
+
+.. contents::
+ :local:
+ :depth: 1
+
+Biological Models
+-----------------
+
+.. autosummary::
+ :toctree: generated/
+
+ HH
+ MorrisLecar
+ PinskyRinzelModel
+ WangBuzsakiModel
+
+
+Fractional-order Models
+-----------------------
+
+.. autosummary::
+ :toctree: generated/
+
+ FractionalNeuron
+ FractionalFHR
+ FractionalIzhikevich
+
+
+Reduced Models
+--------------
+
+.. autosummary::
+ :toctree: generated/
+
+ LeakyIntegrator
+ LIF
+ ExpIF
+ AdExIF
+ QuaIF
+ AdQuaIF
+ GIF
+ ALIFBellec2020
+ Izhikevich
+ HindmarshRose
+ FHN
+
+
+Noise Models
+------------
+
+.. autosummary::
+ :toctree: generated/
+
+ OUProcess
+
+
+Input Models
+------------
+
+.. autosummary::
+ :toctree: generated/
+
+ InputGroup
+ OutputGroup
+ SpikeTimeGroup
+ PoissonGroup
+
+
diff --git a/docs/apis/rates.rst b/docs/apis/rates.rst
new file mode 100644
index 000000000..3c56f148f
--- /dev/null
+++ b/docs/apis/rates.rst
@@ -0,0 +1,16 @@
+``brainpy.rates`` module
+========================
+
+.. currentmodule:: brainpy.rates
+.. automodule:: brainpy.rates
+
+.. autosummary::
+ :toctree: generated/
+
+ RateModel
+ FHN
+ FeedbackFHN
+ QIF
+ StuartLandauOscillator
+ WilsonCowanModel
+ ThresholdLinearModel
diff --git a/docs/apis/synapses.rst b/docs/apis/synapses.rst
new file mode 100644
index 000000000..b79f3fde1
--- /dev/null
+++ b/docs/apis/synapses.rst
@@ -0,0 +1,52 @@
+``brainpy.synapses`` module
+===========================
+
+.. currentmodule:: brainpy.synapses
+.. automodule:: brainpy.synapses
+
+.. contents::
+ :local:
+ :depth: 1
+
+Synaptic Dynamics
+-----------------
+
+.. autosummary::
+ :toctree: generated/
+
+ Delta
+ Exponential
+ DualExponential
+ Alpha
+ NMDA
+ PoissonInput
+ AMPA
+ GABAa
+ BioNMDA
+ DelayCoupling
+ DiffusiveCoupling
+ AdditiveCoupling
+ GapJunction
+
+
+Synaptic Output
+---------------
+
+.. autosummary::
+ :toctree: generated/
+
+ COBA
+ CUBA
+ MgBlock
+
+
+Synaptic Plasticity
+-------------------
+
+.. autosummary::
+ :toctree: generated/
+
+ STD
+ STP
+
+
diff --git a/docs/apis/synouts.rst b/docs/apis/synouts.rst
new file mode 100644
index 000000000..4ea547d59
--- /dev/null
+++ b/docs/apis/synouts.rst
@@ -0,0 +1,28 @@
+``brainpy.synouts`` module
+===========================
+
+.. currentmodule:: brainpy.synouts
+.. automodule:: brainpy.synouts
+
+.. contents::
+ :local:
+ :depth: 1
+
+.. autosummary::
+ :toctree: generated/
+
+ COBA
+ CUBA
+ MgBlock
+
+
+Synaptic Plasticity
+-------------------
+
+.. autosummary::
+ :toctree: generated/
+
+ STD
+ STP
+
+
diff --git a/docs/apis/synplast.rst b/docs/apis/synplast.rst
new file mode 100644
index 000000000..b98938b52
--- /dev/null
+++ b/docs/apis/synplast.rst
@@ -0,0 +1,20 @@
+``brainpy.synplast`` module
+===========================
+
+.. currentmodule:: brainpy.synplast
+.. automodule:: brainpy.synplast
+
+.. contents::
+ :local:
+ :depth: 1
+
+Synaptic Plasticity
+-------------------
+
+.. autosummary::
+ :toctree: generated/
+
+ STD
+ STP
+
+
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index b6a1eb838..77b6332f9 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -379,24 +379,20 @@ def generate_inputs_docs():
header='``brainpy.inputs`` module')
-def generate_layers_docs():
+def generate_dnn_docs():
_write_subsections_v2(
- 'brainpy._src.dnn',
+ 'brainpy.dnn',
'brainpy.dnn',
'apis/auto/dnn.rst',
subsections={
- 'base': 'Basic ANN Layer Class',
'activations': 'Non-linear Activations',
'conv': 'Convolutional Layers',
- 'dropout': 'Dropout Layers',
- 'function': 'Function Layers',
'linear': 'Dense Connection Layers',
'normalization': 'Normalization Layers',
- 'nvar': 'NVAR Layers',
'pooling': 'Pooling Layers',
- 'reservoir': 'Reservoir Layers',
- 'rnncells': 'Artificial Recurrent Layers',
- 'interoperation_flax': 'Interoperation with Flax',
+ 'recurrent': 'Artificial Recurrent Layers',
+ 'interoperation': 'Interoperation with Flax',
+ 'others': 'Other Layers',
}
)
@@ -407,11 +403,15 @@ def generate_dyn_docs():
'brainpy.dyn',
'apis/auto/dyn.rst',
subsections={
+ 'base': 'Base Classes',
+ 'ions': 'Ion Dynamics',
'channels': 'Ion Channel Dynamics',
'neurons': 'Neuron Dynamics',
'synapses': 'Synaptic Dynamics',
'projections': 'Synaptic Projections',
'others': 'Common Dynamical Models',
+ 'outs': 'Synaptic Output Models',
+ 'rates': 'Population Rate Models',
}
)
@@ -474,16 +474,17 @@ def generate_running_docs():
def generate_synapses_docs():
- _write_subsections_v2(
- 'brainpy.synapses',
- 'brainpy.synapses',
- 'apis/auto/synapses.rst',
- subsections={
- 'dynamics': 'Synaptic Dynamics',
- 'synouts': 'Synaptic Output',
- 'synplast': 'Synaptic Plasticity',
- }
- )
+ _write_module(module_name='brainpy.synapses',
+ filename='apis/auto/synapses.rst',
+ header='``brainpy.synapses`` module')
+
+ _write_module(module_name='brainpy.synouts',
+ filename='apis/auto/synouts.rst',
+ header='``brainpy.synouts`` module')
+
+ _write_module(module_name='brainpy.synplast',
+ filename='apis/auto/synplast.rst',
+ header='``brainpy.synplast`` module')
def generate_brainpy_docs():
@@ -498,17 +499,12 @@ def generate_brainpy_docs():
'sdeint',
'fdeint'],
'Building Dynamical System': ['DynamicalSystem',
- 'Container',
+ 'DynSysGroup',
'Sequential',
'Network',
- 'NeuGroup',
- 'SynConn',
- 'SynOut',
- 'SynSTP',
- 'SynLTP',
- 'TwoEndConn',
- 'CondNeuGroup',
- 'Channel',
+ 'Dynamic',
+ 'Projection',
+ 'AnnLayer',
],
'Simulating Dynamical System': ['DSRunner'],
'Training Dynamical System': ['DSTrainer',
@@ -518,7 +514,7 @@ def generate_brainpy_docs():
'ForceTrainer',
'OfflineTrainer',
'RidgeTrainer'],
- 'Dynamical System Helpers': ['DSPartial', 'NoSharedArg', 'LoopOverTime'],
+ 'Dynamical System Helpers': ['LoopOverTime'],
}
)
diff --git a/docs/conf.py b/docs/conf.py
index 344939a97..f584fb7a8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -23,22 +23,23 @@
os.makedirs('apis/auto/', exist_ok=True)
auto_generater.generate_analysis_docs()
auto_generater.generate_connect_docs()
-auto_generater.generate_channels_docs()
auto_generater.generate_encoding_docs()
auto_generater.generate_initialize_docs()
auto_generater.generate_inputs_docs()
-auto_generater.generate_layers_docs()
+auto_generater.generate_dnn_docs()
auto_generater.generate_dyn_docs()
auto_generater.generate_losses_docs()
auto_generater.generate_measure_docs()
-auto_generater.generate_neurons_docs()
auto_generater.generate_optim_docs()
-auto_generater.generate_rates_docs()
auto_generater.generate_running_docs()
-auto_generater.generate_synapses_docs()
auto_generater.generate_brainpy_docs()
auto_generater.generate_integrators_doc()
auto_generater.generate_math_docs()
+# auto_generater.generate_channels_docs()
+# auto_generater.generate_layers_docs()
+# auto_generater.generate_neurons_docs()
+# auto_generater.generate_rates_docs()
+# auto_generater.generate_synapses_docs()
changelogs = [
diff --git a/docs/index.rst b/docs/index.rst
index 071d027aa..9d1b55d5e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -99,10 +99,6 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
apis/auto/math.rst
apis/auto/dnn.rst
apis/auto/dyn.rst
- apis/auto/channels.rst
- apis/auto/neurons.rst
- apis/auto/rates.rst
- apis/auto/synapses.rst
apis/auto/integrators.rst
apis/auto/analysis.rst
apis/auto/connect.rst
@@ -116,6 +112,19 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
apis/auto/changelog.rst
+.. toctree::
+ :maxdepth: 1
+ :caption: Deprecated APIs
+
+ apis/channels.rst
+ apis/neurons.rst
+ apis/rates.rst
+ apis/synapses.rst
+ apis/synouts.rst
+ apis/synplast.rst
+ apis/layers.rst
+
+
Indices and tables
==================
From 99aebf3d4a950e0dd6fe297308da232f52f13078 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 17:20:56 +0800
Subject: [PATCH 020/326] update deprecated API docs
---
docs/apis/channels.rst | 14 +++++++++++---
docs/apis/layers.rst | 15 ++++++++++++---
docs/apis/neurons.rst | 39 ++++++++++++++++-----------------------
docs/apis/rates.rst | 24 ++++++++++++++----------
docs/apis/synapses.rst | 9 ++++++---
docs/apis/synouts.rst | 14 --------------
docs/apis/synplast.rst | 7 -------
7 files changed, 59 insertions(+), 63 deletions(-)
diff --git a/docs/apis/channels.rst b/docs/apis/channels.rst
index cad21004d..3e8bd600e 100644
--- a/docs/apis/channels.rst
+++ b/docs/apis/channels.rst
@@ -4,7 +4,15 @@
.. currentmodule:: brainpy.channels
.. automodule:: brainpy.channels
-.. contents::
- :local:
- :depth: 1
+
+The ``brainpy.channels`` module is deprecated and its models have been moved into ``brainpy.dyn``.
+Although all models can still be accessed through ``brainpy.channels.xxx`` as in older
+versions of BrainPy, we recommend using ``brainpy.dyn.xxx`` instead.
+
+In general, from ``brainpy>=2.4.3``, we provide two modules:
+
+- ``brainpy.dnn`` for modeling deep neural networks
+- ``brainpy.dyn`` for modeling brain dynamics models
+
+
diff --git a/docs/apis/layers.rst b/docs/apis/layers.rst
index 46fcdd905..89542128b 100644
--- a/docs/apis/layers.rst
+++ b/docs/apis/layers.rst
@@ -4,7 +4,16 @@
.. currentmodule:: brainpy.layers
.. automodule:: brainpy.layers
-.. contents::
- :local:
- :depth: 1
+
+The ``brainpy.layers`` module is deprecated and has been renamed to ``brainpy.dnn``. Although all models
+can still be accessed through ``brainpy.layers.xxx`` as in older versions of BrainPy, we recommend
+using ``brainpy.dnn.xxx`` instead.
+
+
+In general, from ``brainpy>=2.4.3``, we provide two modules:
+
+- ``brainpy.dnn`` for modeling deep neural networks
+- ``brainpy.dyn`` for modeling brain dynamics models
+
+
diff --git a/docs/apis/neurons.rst b/docs/apis/neurons.rst
index 5c53a4a4f..85a859c7e 100644
--- a/docs/apis/neurons.rst
+++ b/docs/apis/neurons.rst
@@ -8,6 +8,20 @@
:local:
:depth: 1
+
+Since ``brainpy>=2.4.3``, most models in ``brainpy.neurons`` have been reimplemented in the ``brainpy.dyn`` module.
+
+However, ``brainpy.neurons`` remains independent of the ``brainpy.dyn`` module.
+
+The most significant differences between models in ``brainpy.neurons`` and ``brainpy.dyn`` are that:
+
+- the former only supports an integration style without liquid time constants (i.e., the
+  time constants in these neuron models are fixed once the model is initialized);
+- the former supports SDE integration by specifying the ``noise`` parameter, for example,
+  ``brainpy.neurons.HH(size, ..., noise=1.)``;
+- the former has one additional ``input`` variable for receiving external inputs.
+
+
Biological Models
-----------------
@@ -44,30 +58,9 @@ Reduced Models
QuaIF
AdQuaIF
GIF
- ALIFBellec2020
Izhikevich
HindmarshRose
FHN
-
-
-Noise Models
-------------
-
-.. autosummary::
- :toctree: generated/
-
- OUProcess
-
-
-Input Models
-------------
-
-.. autosummary::
- :toctree: generated/
-
- InputGroup
- OutputGroup
- SpikeTimeGroup
- PoissonGroup
-
+ ALIFBellec2020
+ LIF_SFA_Bellec2020
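
A short sketch of the distinction described in the note above, assuming HH is among the models reimplemented in brainpy.dyn:

    import brainpy as bp

    # Deprecated-but-available interface: fixed time constants, optional SDE
    # noise, and an extra `input` variable for external currents.
    old_hh = bp.neurons.HH(10, noise=1.)

    # Reimplemented counterpart in the unified dynamics module (brainpy >= 2.4.3).
    new_hh = bp.dyn.HH(10)
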
diff --git a/docs/apis/rates.rst b/docs/apis/rates.rst
index 3c56f148f..c0fde5cd9 100644
--- a/docs/apis/rates.rst
+++ b/docs/apis/rates.rst
@@ -4,13 +4,17 @@
.. currentmodule:: brainpy.rates
.. automodule:: brainpy.rates
-.. autosummary::
- :toctree: generated/
-
- RateModel
- FHN
- FeedbackFHN
- QIF
- StuartLandauOscillator
- WilsonCowanModel
- ThresholdLinearModel
+
+
+The ``brainpy.rates`` module is deprecated and its models have been moved into ``brainpy.dyn``. Although all models
+can still be accessed through ``brainpy.rates.xxx`` as in older versions of BrainPy, we recommend
+using ``brainpy.dyn.xxx`` instead.
+
+
+In general, from ``brainpy>=2.4.3``, we provide two modules:
+
+- ``brainpy.dnn`` for modeling deep neural networks
+- ``brainpy.dyn`` for modeling brain dynamics models
+
+
+
diff --git a/docs/apis/synapses.rst b/docs/apis/synapses.rst
index b79f3fde1..82e4fec35 100644
--- a/docs/apis/synapses.rst
+++ b/docs/apis/synapses.rst
@@ -4,9 +4,12 @@
.. currentmodule:: brainpy.synapses
.. automodule:: brainpy.synapses
-.. contents::
- :local:
- :depth: 1
+
+
+Since ``brainpy>=2.4.3``, most models in ``brainpy.synapses`` have been reimplemented in the ``brainpy.dyn`` module.
+
+However, ``brainpy.synapses`` remains independent of the ``brainpy.dyn`` module.
+
Synaptic Dynamics
-----------------
diff --git a/docs/apis/synouts.rst b/docs/apis/synouts.rst
index 4ea547d59..a82e0732b 100644
--- a/docs/apis/synouts.rst
+++ b/docs/apis/synouts.rst
@@ -4,9 +4,6 @@
.. currentmodule:: brainpy.synouts
.. automodule:: brainpy.synouts
-.. contents::
- :local:
- :depth: 1
.. autosummary::
:toctree: generated/
@@ -15,14 +12,3 @@
CUBA
MgBlock
-
-Synaptic Plasticity
--------------------
-
-.. autosummary::
- :toctree: generated/
-
- STD
- STP
-
-
diff --git a/docs/apis/synplast.rst b/docs/apis/synplast.rst
index b98938b52..5ee1efba9 100644
--- a/docs/apis/synplast.rst
+++ b/docs/apis/synplast.rst
@@ -4,13 +4,6 @@
.. currentmodule:: brainpy.synplast
.. automodule:: brainpy.synplast
-.. contents::
- :local:
- :depth: 1
-
-Synaptic Plasticity
--------------------
-
.. autosummary::
:toctree: generated/
From 0510268a05e4ce7272648a030d09548b2c13904b Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 20:13:30 +0800
Subject: [PATCH 021/326] update docs
---
.readthedocs.yml | 4 ----
brainpy/_add_deprecations.py | 1 -
brainpy/_src/connect/custom_conn.py | 4 ++--
3 files changed, 2 insertions(+), 7 deletions(-)
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 0086e9718..82cdd086b 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -14,10 +14,6 @@ build:
sphinx:
configuration: docs/conf.py
-# Optionally build your docs in additional formats such as PDF and ePub
-formats:
- - epub
-
# Optionally set the version of Python and requirements required to build your docs
python:
install:
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
index 77208381c..bd397ba24 100644
--- a/brainpy/_add_deprecations.py
+++ b/brainpy/_add_deprecations.py
@@ -88,7 +88,6 @@
# synapses
'SynConn': ('brainpy.dyn.SynConn', 'brainpy.synapses.SynConn', synapses.SynConn),
- # 'SynLTP': ('brainpy.dyn.SynLTP', 'brainpy.synapses.SynLTP', synapses.SynLTP),
'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
'DeltaSynapse': ('brainpy.dyn.DeltaSynapse', 'brainpy.synapses.Delta', synapses.DeltaSynapse),
diff --git a/brainpy/_src/connect/custom_conn.py b/brainpy/_src/connect/custom_conn.py
index ecf1283e0..ca2cb6910 100644
--- a/brainpy/_src/connect/custom_conn.py
+++ b/brainpy/_src/connect/custom_conn.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-
+import jax
import jax.numpy as jnp
import numpy as np
@@ -22,7 +22,7 @@ class MatConn(TwoEndConnector):
def __init__(self, conn_mat, **kwargs):
super(MatConn, self).__init__(**kwargs)
- assert isinstance(conn_mat, (np.ndarray, bm.Array, jnp.ndarray)) and conn_mat.ndim == 2
+ assert isinstance(conn_mat, (np.ndarray, bm.Array, jax.Array)) and conn_mat.ndim == 2
self.pre_num, self.post_num = conn_mat.shape
self.pre_size, self.post_size = (self.pre_num,), (self.post_num,)
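
With the broadened isinstance check, a jax.Array connection matrix is accepted by MatConn directly. A minimal sketch following the two-end connector calling convention:

    import jax.numpy as jnp
    import brainpy as bp

    # A 2 x 3 boolean connection matrix held as a jax.Array now passes the check.
    mat = jnp.asarray([[True, False, True],
                       [False, True, False]])
    conn = bp.connect.MatConn(mat)(pre_size=2, post_size=3)
    conn_mat = conn.require(bp.connect.CONN_MAT)
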
From 7462d7a1f2775e2de2256b7d5372f294271f58e0 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 9 Jul 2023 23:06:30 +0800
Subject: [PATCH 022/326] add note for API changing
---
docs/index.rst | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/docs/index.rst b/docs/index.rst
index 9d1b55d5e..d0d9d0f45 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -22,6 +22,12 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
.. _BrainPy: https://github.com/brainpy/BrainPy
+.. note::
+ BrainPy is still a research experimental project.
+ APIs may be changed over time. Please always keep
+ in mind which BrainPy version are you using.
+
+
.. toctree::
:maxdepth: 1
From 9baa7f81fc8dae90ce483d1a26bc45d211755f8b Mon Sep 17 00:00:00 2001
From: chaoming
Date: Mon, 10 Jul 2023 15:11:11 +0800
Subject: [PATCH 023/326] add API documentation of `brainpy.mixin` module
---
docs/auto_generater.py | 6 ++++++
docs/index.rst | 5 +++--
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index 77b6332f9..3bca449e7 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -379,6 +379,12 @@ def generate_inputs_docs():
header='``brainpy.inputs`` module')
+def generate_mixin_docs():
+ _write_module(module_name='brainpy.mixin',
+ filename='apis/auto/mixin.rst',
+ header='``brainpy.mixin`` module')
+
+
def generate_dnn_docs():
_write_subsections_v2(
'brainpy.dnn',
diff --git a/docs/index.rst b/docs/index.rst
index d0d9d0f45..57b039ac6 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -23,9 +23,9 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
.. note::
- BrainPy is still a research experimental project.
+ BrainPy is still an experimental research project.
APIs may be changed over time. Please always keep
- in mind which BrainPy version are you using.
+ in mind what BrainPy version you are using.
@@ -115,6 +115,7 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
apis/auto/measure.rst
apis/auto/optim.rst
apis/auto/running.rst
+ apis/auto/mixin.rst
apis/auto/changelog.rst
From efb3ab4dd2ced29ce53f81e96c5b2cc6fafeb5a3 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Mon, 10 Jul 2023 16:07:07 +0800
Subject: [PATCH 024/326] advanced tutorial multi-level
---
docs/index.rst | 13 ++++---------
docs/tutorial_advanced/analysis.rst | 7 +++++++
docs/tutorial_advanced/interoperation.rst | 9 +++++++++
docs/tutorial_advanced/math.rst | 9 +++++++++
4 files changed, 29 insertions(+), 9 deletions(-)
create mode 100644 docs/tutorial_advanced/analysis.rst
create mode 100644 docs/tutorial_advanced/interoperation.rst
create mode 100644 docs/tutorial_advanced/math.rst
diff --git a/docs/index.rst b/docs/index.rst
index 57b039ac6..bf1a38560 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -59,17 +59,12 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
.. toctree::
- :maxdepth: 1
+ :maxdepth: 2
:caption: Advanced Tutorials
- tutorial_advanced/how_to_debug.ipynb
- tutorial_advanced/gotchas_of_brainpy_transforms.ipynb
- tutorial_advanced/advanced_lowdim_analysis.ipynb
- tutorial_advanced/differentiation.ipynb
- tutorial_advanced/integrate_flax_into_brainpy.ipynb
- tutorial_advanced/integrate_bp_lif_into_flax.ipynb
- tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
-
+ tutorial_advanced/math.rst
+ tutorial_advanced/interoperation.rst
+ tutorial_advanced/analysis.rst
.. toctree::
diff --git a/docs/tutorial_advanced/analysis.rst b/docs/tutorial_advanced/analysis.rst
new file mode 100644
index 000000000..29d8d3886
--- /dev/null
+++ b/docs/tutorial_advanced/analysis.rst
@@ -0,0 +1,7 @@
+Analysis
+================
+
+.. toctree::
+ :maxdepth: 1
+
+ advanced_lowdim_analysis.ipynb
\ No newline at end of file
diff --git a/docs/tutorial_advanced/interoperation.rst b/docs/tutorial_advanced/interoperation.rst
new file mode 100644
index 000000000..7e1857765
--- /dev/null
+++ b/docs/tutorial_advanced/interoperation.rst
@@ -0,0 +1,9 @@
+Interoperation
+================
+
+.. toctree::
+ :maxdepth: 1
+
+ integrate_flax_into_brainpy.ipynb
+ integrate_bp_lif_into_flax.ipynb
+ integrate_bp_convlstm_into_flax.ipynb
diff --git a/docs/tutorial_advanced/math.rst b/docs/tutorial_advanced/math.rst
new file mode 100644
index 000000000..c66e31673
--- /dev/null
+++ b/docs/tutorial_advanced/math.rst
@@ -0,0 +1,9 @@
+Advanced Math
+=============
+
+.. toctree::
+ :maxdepth: 1
+
+ how_to_debug.ipynb
+ gotchas_of_brainpy_transforms.ipynb
+ differentiation.ipynb
From 0318133b6c70ce1762f7f66c10247df3cf5c5a0c Mon Sep 17 00:00:00 2001
From: GYF <1337838189@qq.com>
Date: Tue, 11 Jul 2023 10:56:09 +0800
Subject: [PATCH 025/326] test
---
brainpy/_src/dnn/activations.py | 71 ++--
brainpy/_src/dnn/conv.py | 6 +-
brainpy/_src/dnn/pooling.py | 2 +
brainpy/_src/dnn/tests/test_activation.py | 237 ++++++++++++
brainpy/_src/dnn/tests/test_conv_layers.py | 360 +++++++++++-------
brainpy/_src/dnn/tests/test_function.py | 33 ++
brainpy/_src/dnn/tests/test_linear.py | 22 +-
brainpy/_src/dnn/tests/test_normalization.py | 57 +++
brainpy/_src/dnn/tests/test_pooling_layers.py | 345 ++++++++++-------
9 files changed, 829 insertions(+), 304 deletions(-)
create mode 100644 brainpy/_src/dnn/tests/test_activation.py
create mode 100644 brainpy/_src/dnn/tests/test_function.py
create mode 100644 brainpy/_src/dnn/tests/test_normalization.py
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index e9f342319..8e087435d 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -12,7 +12,7 @@
def _inplace(inp, val, inplace):
if inplace:
- assert isinstance(input, bm.Array), 'input must be instance of brainpy.math.Array if inplace=True'
+ assert isinstance(inp, bm.Array), 'input must be instance of brainpy.math.Array if inplace=True'
inp.value = val
return inp
else:
@@ -44,7 +44,7 @@ class Threshold(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Threshold(0.1, 20)
+ >>> m = bp.dnn.Threshold(0.1, 20)
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -87,7 +87,7 @@ class ReLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.ReLU()
+ >>> m = bp.dnn.ReLU()
>>> input = bm.random.randn(2)
>>> output = m(input)
@@ -96,7 +96,7 @@ class ReLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.ReLU()
+ >>> m = bp.dnn.ReLU()
>>> input = bm.random.randn(2).unsqueeze(0)
>>> output = bm.cat((m(input), m(-input)))
"""
@@ -149,7 +149,7 @@ class RReLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.RReLU(0.1, 0.3)
+ >>> m = bp.dnn.RReLU(0.1, 0.3)
>>> input = bm.random.randn(2)
>>> output = m(input)
@@ -210,7 +210,7 @@ class Hardtanh(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Hardtanh(-2, 2)
+ >>> m = bp.dnn.Hardtanh(-2, 2)
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -260,7 +260,7 @@ class ReLU6(Hardtanh):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.ReLU6()
+ >>> m = bp.dnn.ReLU6()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -288,7 +288,7 @@ class Sigmoid(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Sigmoid()
+ >>> m = bp.dnn.Sigmoid()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -320,7 +320,7 @@ class Hardsigmoid(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Hardsigmoid()
+ >>> m = bp.dnn.Hardsigmoid()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -353,7 +353,7 @@ class Tanh(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Tanh()
+ >>> m = bp.dnn.Tanh()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -376,6 +376,8 @@ class SiLU(Layer):
in Reinforcement Learning `_ and `Swish:
a Self-Gated Activation Function `_
where the SiLU was experimented with later.
+ Args:
+ inplace: can optionally do the operation in-place. Default: ``False``
Shape:
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
@@ -385,7 +387,7 @@ class SiLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.SiLU()
+ >>> m = bp.dnn.SiLU()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -414,6 +416,9 @@ class Mish(Layer):
.. note::
See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_
+ Args:
+ inplace: can optionally do the operation in-place. Default: ``False``
+
Shape:
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Output: :math:`(*)`, same shape as the input.
@@ -422,7 +427,7 @@ class Mish(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Mish()
+ >>> m = bp.dnn.Mish()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -465,7 +470,7 @@ class Hardswish(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Hardswish()
+ >>> m = bp.dnn.Hardswish()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -506,7 +511,7 @@ class ELU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.ELU()
+ >>> m = bp.dnn.ELU()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -547,7 +552,7 @@ class CELU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.CELU()
+ >>> m = bp.dnn.CELU()
>>> input = bm.random.randn(2)
>>> output = m(input)
@@ -593,7 +598,7 @@ class SELU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.SELU()
+ >>> m = bp.dnn.SELU()
>>> input = bm.random.randn(2)
>>> output = m(input)
@@ -631,7 +636,7 @@ class GLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.GLU()
+ >>> m = bp.dnn.GLU()
>>> input = bm.random.randn(4, 2)
>>> output = m(input)
"""
@@ -672,7 +677,7 @@ class GELU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.GELU()
+ >>> m = bp.dnn.GELU()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -714,7 +719,7 @@ class Hardshrink(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Hardshrink()
+ >>> m = bp.dnn.Hardshrink()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -762,7 +767,7 @@ class LeakyReLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.LeakyReLU(0.1)
+ >>> m = bp.dnn.LeakyReLU(0.1)
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -797,7 +802,7 @@ class LogSigmoid(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.LogSigmoid()
+ >>> m = bp.dnn.LogSigmoid()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -828,7 +833,7 @@ class Softplus(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Softplus()
+ >>> m = bp.dnn.Softplus()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -870,7 +875,7 @@ class Softshrink(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Softshrink()
+ >>> m = bp.dnn.Softshrink()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -903,8 +908,8 @@ class PReLU(Layer):
ax, & \text{ otherwise }
\end{cases}
- Here :math:`a` is a learnable parameter. When called without arguments, `bp.layers.PReLU()` uses a single
- parameter :math:`a` across all input channels. If called with `bp.layers.PReLU(nChannels)`,
+ Here :math:`a` is a learnable parameter. When called without arguments, `bp.dnn.PReLU()` uses a single
+ parameter :math:`a` across all input channels. If called with `bp.dnn.PReLU(nChannels)`,
a separate :math:`a` is used for each input channel.
@@ -933,7 +938,7 @@ class PReLU(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.PReLU()
+ >>> m = bp.dnn.PReLU()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -966,7 +971,7 @@ class Softsign(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Softsign()
+ >>> m = bp.dnn.Softsign()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -989,7 +994,7 @@ class Tanhshrink(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Tanhshrink()
+ >>> m = bp.dnn.Tanhshrink()
>>> input = bm.random.randn(2)
>>> output = m(input)
"""
@@ -1025,7 +1030,7 @@ class Softmin(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Softmin(dim=1)
+ >>> m = bp.dnn.Softmin(dim=1)
>>> input = bm.random.randn(2, 3)
>>> output = m(input)
"""
@@ -1078,7 +1083,7 @@ class Softmax(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Softmax(dim=1)
+ >>> m = bp.dnn.Softmax(dim=1)
>>> input = bm.random.randn(2, 3)
>>> output = m(input)
@@ -1115,14 +1120,14 @@ class Softmax2d(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.Softmax2d()
+ >>> m = bp.dnn.Softmax2d()
>>> # you softmax over the 2nd dimension
>>> input = bm.random.randn(2, 3, 12, 13)
>>> output = m(input)
"""
def update(self, input: ArrayType) -> ArrayType:
- assert input.dim() == 4 or input.dim() == 3, 'Softmax2d requires a 3D or 4D tensor as input'
+ assert input.ndim == 4 or input.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input'
return bm.softmax(input, -3)
@@ -1149,7 +1154,7 @@ class LogSoftmax(Layer):
>>> import brainpy as bp
>>> import brainpy.math as bm
- >>> m = bp.layers.LogSoftmax(dim=1)
+ >>> m = bp.dnn.LogSoftmax(dim=1)
>>> input = bm.random.randn(2, 3)
>>> output = m(input)
"""
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index 566949579..e878e2204 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -462,6 +462,8 @@ def _check_input_dim(self, x):
class _GeneralConvTranspose(Layer):
+
+
def __init__(
self,
num_spatial_dims: int,
@@ -604,6 +606,8 @@ def __init__(
)
def _check_input_dim(self, x):
+ if isinstance(self.mode, bm.BatchingMode):
+ pass
if x.ndim != 3:
raise ValueError(f"expected 3D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
@@ -707,7 +711,7 @@ def __init__(
name: The name of the module.
"""
super().__init__(
- num_spatial_dims=1,
+ num_spatial_dims=3,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
diff --git a/brainpy/_src/dnn/pooling.py b/brainpy/_src/dnn/pooling.py
index 3ff24d8a4..07bc11024 100644
--- a/brainpy/_src/dnn/pooling.py
+++ b/brainpy/_src/dnn/pooling.py
@@ -771,6 +771,8 @@ def update(self, x):
# channel axis
channel_axis = self.channel_axis
+
+
if channel_axis:
if not 0 <= abs(channel_axis) < x.ndim:
raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py
new file mode 100644
index 000000000..2915b0f35
--- /dev/null
+++ b/brainpy/_src/dnn/tests/test_activation.py
@@ -0,0 +1,237 @@
+import brainpy.math as bm
+from absl.testing import parameterized
+from absl.testing import absltest
+import brainpy as bp
+
+
+class Test_Activation(parameterized.TestCase):
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Threshold(self, inplace):
+ bm.random.seed()
+ threshold_layer = bp.dnn.Threshold(5, 20, inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ threshold_layer(input)
+ elif inplace == False:
+ output = threshold_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_ReLU(self, inplace):
+ ReLU_layer = bp.dnn.ReLU(inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ ReLU_layer(input)
+ elif inplace == False:
+ output = ReLU_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_RReLU(self, inplace):
+ RReLU_layer = bp.dnn.RReLU(lower=0, upper=1, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ RReLU_layer(input)
+ elif inplace == False:
+ output = RReLU_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Hardtanh(self, inplace):
+ Hardtanh_layer = bp.dnn.Hardtanh(min_val=0, max_val=1, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Hardtanh_layer(input)
+ elif inplace == False:
+ output = Hardtanh_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_ReLU6(self, inplace):
+ ReLU6_layer = bp.dnn.ReLU6(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ ReLU6_layer(input)
+ elif inplace == False:
+ output = ReLU6_layer(input)
+
+ def test_Sigmoid(self):
+ Sigmoid_layer = bp.dnn.Sigmoid()
+ input = bm.random.randn(2)
+ output = Sigmoid_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Hardsigmoid(self, inplace):
+ Hardsigmoid_layer = bp.dnn.Hardsigmoid(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Hardsigmoid_layer(input)
+ elif inplace == False:
+ output = Hardsigmoid_layer(input)
+
+ def test_Tanh(self):
+ Tanh_layer = bp.dnn.Tanh()
+ input = bm.random.randn(2)
+ output = Tanh_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_SiLU(self, inplace):
+ SiLU_layer = bp.dnn.SiLU(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ SiLU_layer(input)
+ elif inplace == False:
+ output = SiLU_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Mish(self, inplace):
+ Mish_layer = bp.dnn.Mish(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Mish_layer(input)
+ elif inplace == False:
+ output = Mish_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Hardswish(self, inplace):
+ Hardswish_layer = bp.dnn.Hardswish(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Hardswish_layer(input)
+ elif inplace == False:
+ output = Hardswish_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_ELU(self, inplace):
+ ELU_layer = bp.dnn.ELU(alpha=0.5, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ ELU_layer(input)
+ elif inplace == False:
+ output = ELU_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_CELU(self, inplace):
+ CELU_layer = bp.dnn.CELU(alpha=0.5, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ CELU_layer(input)
+ elif inplace == False:
+ output = CELU_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_SELU(self, inplace):
+ SELU_layer = bp.dnn.SELU(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ SELU_layer(input)
+ elif inplace == False:
+ output = SELU_layer(input)
+
+ def test_GLU(self):
+ GLU_layer = bp.dnn.GLU()
+ input = bm.random.randn(4, 2)
+ output = GLU_layer(input)
+
+ @parameterized.product(
+ approximate=['tanh', 'none']
+ )
+ def test_GELU(self, approximate):
+ GELU_layer = bp.dnn.GELU()
+ input = bm.random.randn(2)
+ output = GELU_layer(input)
+
+ def test_Hardshrink(self):
+ Hardshrink_layer = bp.dnn.Hardshrink(lambd=1)
+ input = bm.random.randn(2)
+ output = Hardshrink_layer(input)
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_LeakyReLU(self, inplace):
+ LeakyReLU_layer = bp.dnn.LeakyReLU(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ LeakyReLU_layer(input)
+ elif inplace == False:
+ output = LeakyReLU_layer(input)
+
+ def test_LogSigmoid(self):
+ LogSigmoid_layer = bp.dnn.LogSigmoid()
+ input = bm.random.randn(2)
+ output = LogSigmoid_layer(input)
+
+ @parameterized.product(
+ beta=[1, 2, 3],
+ threshold=[20, 21, 22]
+ )
+ def test_Softplus(self, beta, threshold):
+ Softplus_layer = bp.dnn.Softplus(beta=beta, threshold=threshold)
+ input = bm.random.randn(2)
+ output = Softplus_layer(input)
+
+ def test_Softshrink(self):
+ Softshrink_layer = bp.dnn.Softshrink(lambd=1)
+ input = bm.random.randn(2)
+ output = Softshrink_layer(input)
+
+ def test_PReLU(self):
+ PReLU_layer = bp.dnn.PReLU(num_parameters=2, init=0.5)
+ input = bm.random.randn(2)
+ output = PReLU_layer(input)
+
+ def test_Softsign(self):
+ Softsign_layer = bp.dnn.Softsign()
+ input = bm.random.randn(2)
+ output = Softsign_layer(input)
+
+ def test_Tanhshrink(self):
+ Tanhshrink_layer = bp.dnn.Tanhshrink()
+ input = bm.random.randn(2)
+ output = Tanhshrink_layer(input)
+
+ def test_Softmin(self):
+ Softmin_layer = bp.dnn.Softmin(dim=2)
+ input = bm.random.randn(2, 3, 4)
+ output = Softmin_layer(input)
+
+ def test_Softmax(self):
+ Softmax_layer = bp.dnn.Softmax(dim=2)
+ input = bm.random.randn(2, 3, 4)
+ output = Softmax_layer(input)
+
+ def test_Softmax2d(self):
+ Softmax2d_layer = bp.dnn.Softmax2d()
+ input = bm.random.randn(2, 3, 12, 13)
+ output = Softmax2d_layer(input)
+
+ def test_LogSoftmax(self):
+ LogSoftmax_layer = bp.dnn.LogSoftmax(dim=2)
+ input = bm.random.randn(2, 3, 4)
+ output = LogSoftmax_layer(input)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py
index 550f87883..71e63682f 100644
--- a/brainpy/_src/dnn/tests/test_conv_layers.py
+++ b/brainpy/_src/dnn/tests/test_conv_layers.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from unittest import TestCase
-
+from absl.testing import absltest
import jax.numpy as jnp
import brainpy.math as bm
@@ -9,135 +9,235 @@
class TestConv(bp.testing.UnitTestCase):
- def test_Conv2D_img(self):
- img = jnp.zeros((2, 200, 198, 4))
- for k in range(4):
- x = 30 + 60 * k
- y = 20 + 60 * k
- img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
- img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
-
- with bp.math.training_environment():
- net = bp.layers.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3),
- strides=(1, 1), padding='SAME', groups=1)
- out = net(img)
- print("out shape: ", out.shape)
- # print("First output channel:")
- # plt.figure(figsize=(10, 10))
- # plt.imshow(np.array(img)[0, :, :, 0])
- # plt.show()
-
- def test_conv1D(self):
- with bp.math.training_environment():
- model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
- input = bp.math.ones((2, 5, 3))
-
- out = model(input)
- print("out shape: ", out.shape)
- # print("First output channel:")
- # plt.figure(figsize=(10, 10))
- # plt.imshow(np.array(out)[0, :, :])
- # plt.show()
-
- def test_conv2D(self):
- with bp.math.training_environment():
- model = bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
-
- input = bp.math.ones((2, 5, 5, 3))
-
- out = model(input)
- print("out shape: ", out.shape)
- # print("First output channel:")
- # plt.figure(figsize=(10, 10))
- # plt.imshow(np.array(out)[0, :, :, 31])
- # plt.show()
-
- def test_conv3D(self):
- with bp.math.training_environment():
- model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3))
- input = bp.math.ones((2, 5, 5, 5, 3))
- out = model(input)
- print("out shape: ", out.shape)
+ def test_Conv2D_img(self):
+ img = jnp.zeros((2, 200, 198, 4))
+ for k in range(4):
+ x = 30 + 60 * k
+ y = 20 + 60 * k
+ img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
+ img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
+
+ with bp.math.training_environment():
+ net = bp.layers.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3),
+ strides=(2, 1), padding='VALID', groups=4)
+ out = net(img)
+ print("out shape: ", out.shape)
+ # print("First output channel:")
+ # plt.figure(figsize=(10, 10))
+ # plt.imshow(np.array(img)[0, :, :, 0])
+ # plt.show()
+
+ def test_conv1D(self):
+ with bp.math.training_environment():
+ model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
+
+ input = bp.math.ones((2, 5, 3))
+
+ out = model(input)
+ print("out shape: ", out.shape)
+ # print("First output channel:")
+ # plt.figure(figsize=(10, 10))
+ # plt.imshow(np.array(out)[0, :, :])
+ # plt.show()
+
+ def test_conv2D(self):
+ with bp.math.training_environment():
+ model = bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
+
+ input = bp.math.ones((2, 5, 5, 3))
+
+ out = model(input)
+ print("out shape: ", out.shape)
+ # print("First output channel:")
+ # plt.figure(figsize=(10, 10))
+ # plt.imshow(np.array(out)[0, :, :, 31])
+ # plt.show()
+
+ def test_conv3D(self):
+ with bp.math.training_environment():
+ model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3))
+ input = bp.math.ones((2, 5, 5, 5, 3))
+ out = model(input)
+ print("out shape: ", out.shape)
class TestConvTranspose1d(bp.testing.UnitTestCase):
- def test_conv_transpose(self):
- x = bm.ones((1, 8, 3))
- for use_bias in [True, False]:
- conv_transpose_module = bp.layers.ConvTranspose1d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3,),
- padding='VALID',
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.training_mode
- )
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
- y = conv_transpose_module(x)
- print(y.shape)
- correct_ans = jnp.array([[[4., 4., 4., 4.],
- [7., 7., 7., 7.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [7., 7., 7., 7.],
- [4., 4., 4., 4.]]])
- if not use_bias:
- correct_ans -= 1.
- self.assertTrue(bm.allclose(y, correct_ans))
-
- def test_single_input_masked_conv_transpose(self):
- x = jnp.ones((1, 8, 3))
- m = jnp.tril(jnp.ones((3, 3, 4)))
- conv_transpose_module = bp.layers.ConvTranspose1d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3,),
- padding='VALID',
- mask=m,
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit(),
- mode=bm.batching_mode
- )
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
- y = conv_transpose_module(x)
- print(y.shape)
- correct_ans = jnp.array([[[4., 3., 2., 1.],
- [7., 5., 3., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [7., 5., 3., 1.],
- [4., 3., 2., 1.]]])
- self.assertTrue(bm.allclose(y, correct_ans))
-
- def test_computation_padding_same(self):
- data = jnp.ones([1, 3, 1])
- for use_bias in [True, False]:
- net = bp.layers.ConvTranspose1d(
- in_channels=1,
- out_channels=1,
- kernel_size=3,
- stride=1,
- padding="SAME",
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.batching_mode
- )
- out = net(data)
- self.assertEqual(out.shape, (1, 3, 1))
- out = jnp.squeeze(out, axis=(0, 2))
- expected_out = bm.as_jax([2, 3, 2])
- if use_bias:
- expected_out += 1
- self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5))
-
-
-
+ def test_conv_transpose(self):
+ x = bm.ones((1, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose1d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3,),
+ padding='VALID',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+ y = conv_transpose_module(x)
+ print(y.shape)
+ correct_ans = jnp.array([[[4., 4., 4., 4.],
+ [7., 7., 7., 7.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [7., 7., 7., 7.],
+ [4., 4., 4., 4.]]])
+ if not use_bias:
+ correct_ans -= 1.
+ self.assertTrue(bm.allclose(y, correct_ans))
+
+ def test_single_input_masked_conv_transpose(self):
+ x = jnp.ones((1, 8, 3))
+ m = jnp.tril(jnp.ones((3, 3, 4)))
+ conv_transpose_module = bp.layers.ConvTranspose1d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3,),
+ padding='VALID',
+ mask=m,
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit(),
+ mode=bm.batching_mode
+ )
+ self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+ y = conv_transpose_module(x)
+ print(y.shape)
+ correct_ans = jnp.array([[[4., 3., 2., 1.],
+ [7., 5., 3., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [7., 5., 3., 1.],
+ [4., 3., 2., 1.]]])
+ self.assertTrue(bm.allclose(y, correct_ans))
+
+ def test_computation_padding_same(self):
+ data = jnp.ones([1, 3, 1])
+ for use_bias in [True, False]:
+ net = bp.layers.ConvTranspose1d(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding="SAME",
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.batching_mode
+ )
+ out = net(data)
+ self.assertEqual(out.shape, (1, 3, 1))
+ out = jnp.squeeze(out, axis=(0, 2))
+ expected_out = bm.as_jax([2, 3, 2])
+ if use_bias:
+ expected_out += 1
+ self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5))
+
+
+class TestConvTranspose2d(bp.testing.UnitTestCase):
+ def test_conv_transpose(self):
+ x = bm.ones((1, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose2d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3),
+ padding='VALID',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
+ y = conv_transpose_module(x)
+ print(y.shape)
+
+ def test_single_input_masked_conv_transpose(self):
+ x = jnp.ones((1, 8, 8, 3))
+ m = jnp.tril(jnp.ones((3, 3, 3, 4)))
+ conv_transpose_module = bp.layers.ConvTranspose2d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3),
+ padding='VALID',
+ mask=m,
+ w_initializer=bp.init.OneInit(),
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+
+ def test_computation_padding_same(self):
+ x = bm.ones((1, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose2d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3),
+ stride=1,
+ padding='SAME',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode,
+ # mode=bm.nonbatching_mode,
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+
+
+class TestConvTranspose3d(bp.testing.UnitTestCase):
+ def test_conv_transpose(self):
+ x = bm.ones((1, 8, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose3d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3, 3),
+ # padding='VALID',
+ # w_initializer=bp.init.OneInit(),
+ # b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+
+ def test_single_input_masked_conv_transpose(self):
+ x = jnp.ones((1, 8, 8, 8, 3))
+ m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
+ conv_transpose_module = bp.layers.ConvTranspose3d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3, 3),
+ padding='VALID',
+ mask=m,
+ w_initializer=bp.init.OneInit(),
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+
+ def test_computation_padding_same(self):
+ x = bm.ones((1, 8, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose3d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3, 3),
+ stride=1,
+ padding='SAME',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+
+
+if __name__ == '__main__':
+ absltest.main()
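
A note on the hard-coded `correct_ans` in `test_conv_transpose` above: with an all-ones (1, 8, 3) input, an all-ones kernel of size 3 over 3 input channels and 'VALID' padding, each of the 10 output positions is simply the number of kernel taps overlapping the input, times the 3 input channels, plus 1 when the bias is used. A minimal NumPy sketch (independent of the BrainPy API) reproducing one output channel:

import numpy as np

in_len, k, in_ch = 8, 3, 3            # input length, kernel size, input channels
out_len = in_len + k - 1              # 'VALID' transposed conv: 8 + 3 - 1 = 10
overlap = np.zeros(out_len)
for i in range(in_len):               # each input position spreads over k outputs
  overlap[i:i + k] += 1
expected = overlap * in_ch + 1        # all-ones weights and inputs, bias = 1
print(expected)                       # [ 4.  7. 10. 10. 10. 10. 10. 10.  7.  4.]
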
diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py
new file mode 100644
index 000000000..b51efe16f
--- /dev/null
+++ b/brainpy/_src/dnn/tests/test_function.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+from unittest import TestCase
+
+import jax.numpy as jnp
+import brainpy.math as bm
+from absl.testing import absltest
+import brainpy as bp
+
+
+class TestFunction(bp.testing.UnitTestCase):
+
+ def test_flatten_batching_mode(self):
+ layer = bp.dnn.Flatten(mode=bm.BatchingMode())
+ input = bm.random.randn(20, 10, 10, 6)
+
+ output = layer.update(input)
+
+ expected_shape = (20, 600)
+ self.assertEqual(output.shape, expected_shape)
+
+ def test_flatten_non_batching_mode(self):
+ layer = bp.dnn.Flatten(mode=bm.NonBatchingMode())
+ input = bm.random.randn(10, 10, 6)
+
+ output = layer.update(input)
+
+ expected_shape = (600,)
+ self.assertEqual(output.shape, expected_shape)
+
+
+if __name__ == '__main__':
+ absltest.main()
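
The expected shapes in these two tests follow from how `Flatten` treats the leading batch axis: in batching mode only the trailing feature dimensions are merged, while in non-batching mode the whole array is flattened. The shape arithmetic is just a product of dimensions:

import math

batched = (20, 10, 10, 6)
print((batched[0], math.prod(batched[1:])))   # (20, 600)

unbatched = (10, 10, 6)
print((math.prod(unbatched),))                # (600,)
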
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py
index 337536fd2..5ce07d474 100644
--- a/brainpy/_src/dnn/tests/test_linear.py
+++ b/brainpy/_src/dnn/tests/test_linear.py
@@ -1,6 +1,6 @@
import brainpy as bp
from absl.testing import parameterized
-
+from absl.testing import absltest
import brainpy.math as bm
@@ -93,6 +93,24 @@ def test_CSRLinear(self, conn):
y = f(x)
self.assertTrue(y.shape == (100,))
+
+ @parameterized.product(
+ conn=[
+ bp.conn.FixedProb(0.1, pre=100, post=100),
+ bp.conn.GridFour(pre=100, post=100),
+ bp.conn.GaussianProb(0.1, pre=100, post=100),
+ ]
+ )
+  def test_EventCSRLinear(self, conn):
+    f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal())
+ x = bm.random.random((16, 100))
+ y = f(x)
+ self.assertTrue(y.shape == (16, 100))
+ x = bm.random.random((100,))
+ y = f(x)
+ self.assertTrue(y.shape == (100,))
+
+
@parameterized.product(
prob=[0.01, 0.05, 0.5],
weight=[0.01, 0.01],
@@ -170,3 +188,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
self.assertTrue(y2.shape == shape + (200,))
+if __name__ == '__main__':
+ absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py
new file mode 100644
index 000000000..a93a64de0
--- /dev/null
+++ b/brainpy/_src/dnn/tests/test_normalization.py
@@ -0,0 +1,57 @@
+import brainpy.math as bm
+from absl.testing import parameterized
+from absl.testing import absltest
+import brainpy as bp
+
+
+class Test_Normalization(parameterized.TestCase):
+ @parameterized.product(
+ fit=[True, False],
+ )
+ def test_BatchNorm1d(self, fit):
+ net = bp.dnn.BatchNorm1d(num_features=10, mode=bm.training_mode)
+ bp.share.save(fit=fit)
+ input = bm.random.randn(1, 3, 10)
+ output = net(input)
+
+ @parameterized.product(
+ fit=[True, False]
+ )
+ def test_BatchNorm2d(self, fit):
+ net = bp.dnn.BatchNorm2d(10, mode=bm.training_mode)
+ bp.share.save(fit=fit)
+ input = bm.random.randn(1, 3, 4, 10)
+ output = net(input)
+
+ @parameterized.product(
+ fit=[True, False]
+ )
+ def test_BatchNorm3d(self, fit):
+ net = bp.dnn.BatchNorm3d(10, mode=bm.training_mode)
+ bp.share.save(fit=fit)
+ input = bm.random.randn(1, 3, 4, 5, 10)
+ output = net(input)
+
+ @parameterized.product(
+ normalized_shape=(10, [5, 10])
+ )
+ def test_LayerNorm(self, normalized_shape):
+ net = bp.dnn.LayerNorm(normalized_shape, mode=bm.training_mode)
+ input = bm.random.randn(20, 5, 10)
+ output = net(input)
+
+ @parameterized.product(
+ num_groups=[1, 2, 3, 6]
+ )
+ def test_GroupNorm(self, num_groups):
+ input = bm.random.randn(20, 10, 10, 6)
+ net = bp.dnn.GroupNorm(num_groups=num_groups, num_channels=6, mode=bm.training_mode)
+ output = net(input)
+
+ def test_InstanceNorm(self):
+ input = bm.random.randn(20, 10, 10, 6)
+ net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode)
+ output = net(input)
+
+if __name__ == '__main__':
+ absltest.main()
\ No newline at end of file
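
These tests only exercise the forward pass in training mode; the `fit` flag stored with `bp.share.save` is assumed here to be what switches the batch-norm layers between batch statistics and running estimates. For `LayerNorm`, the quantity being checked is the standard normalization over the trailing `normalized_shape` axes, before any learnable scale and offset. A plain NumPy reference of that reduction:

import numpy as np

x = np.random.randn(20, 5, 10)
normalized_shape = (5, 10)
axes = tuple(range(x.ndim - len(normalized_shape), x.ndim))   # (1, 2)
mu = x.mean(axis=axes, keepdims=True)
var = x.var(axis=axes, keepdims=True)
y = (x - mu) / np.sqrt(var + 1e-5)
print(y.mean(axis=axes).max(), y.std(axis=axes).mean())       # ~0 and ~1
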
diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py
index 6367bbc95..b05932cb3 100644
--- a/brainpy/_src/dnn/tests/test_pooling_layers.py
+++ b/brainpy/_src/dnn/tests/test_pooling_layers.py
@@ -4,149 +4,216 @@
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized
+from absl.testing import absltest
import brainpy as bp
import brainpy.math as bm
class TestPool(parameterized.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- self.rng = bm.random.default_rng(12345)
-
- def test_maxpool(self):
- x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
- print(jnp.arange(9).reshape(3, 3))
- print(x)
- print(x.shape)
- shared = {'fit': False}
- with bm.training_environment():
- net = bp.layers.MaxPool((2, 2), 1, channel_axis=-1)
- y = net(shared, x)
- print("out shape: ", y.shape)
- expected_y = jnp.array([[4., 5.],
- [7., 8.]]).reshape((1, 2, 2, 1))
- np.testing.assert_allclose(y, expected_y)
-
- def test_maxpool2(self):
- x = self.rng.rand(10, 20, 20, 4)
- with bm.training_environment():
- net = bp.layers.MaxPool((2, 2), (2, 2), channel_axis=-1)
- y = net(x)
- print("out shape: ", y.shape)
-
- def test_minpool(self):
- x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
- shared = {'fit': False}
- with bm.training_environment():
- net = bp.layers.MinPool((2, 2), 1, channel_axis=-1)
- y = net(shared, x)
- print("out shape: ", y.shape)
- expected_y = jnp.array([
- [0., 1.],
- [3., 4.],
- ]).reshape((1, 2, 2, 1))
- np.testing.assert_allclose(y, expected_y)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.rng = bm.random.default_rng(12345)
+
+ def test_maxpool(self):
+ x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+ print(jnp.arange(9).reshape(3, 3))
+ print(x)
+ print(x.shape)
+ shared = {'fit': False}
+ with bm.training_environment():
+ net = bp.dnn.MaxPool((2, 2), 1, channel_axis=-1)
+ y = net(shared, x)
+ print("out shape: ", y.shape)
+ expected_y = jnp.array([[4., 5.],
+ [7., 8.]]).reshape((1, 2, 2, 1))
+ np.testing.assert_allclose(y, expected_y)
+
+ def test_maxpool2(self):
+ x = self.rng.rand(10, 20, 20, 4)
+ with bm.training_environment():
+ net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1)
+ y = net(x)
+ print("out shape: ", y.shape)
+
+ def test_minpool(self):
+ x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+ shared = {'fit': False}
+ with bm.training_environment():
+ net = bp.dnn.MinPool((2, 2), 1, channel_axis=-1)
+ y = net(shared, x)
+ print("out shape: ", y.shape)
+ expected_y = jnp.array([
+ [0., 1.],
+ [3., 4.],
+ ]).reshape((1, 2, 2, 1))
+ np.testing.assert_allclose(y, expected_y)
- def test_avgpool(self):
- x = jnp.full((1, 3, 3, 1), 2.)
- with bm.training_environment():
- net = bp.layers.AvgPool((2, 2), 1, channel_axis=-1)
- y = net(x)
- print("out shape: ", y.shape)
- np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
-
- def test_MaxPool2d_v1(self):
- arr = self.rng.rand(16, 32, 32, 8)
-
- out = bp.layers.MaxPool2d(2, 2, channel_axis=-1)(arr)
- self.assertTrue(out.shape == (16, 16, 16, 8))
-
- out = bp.layers.MaxPool2d(2, 2, channel_axis=None)(arr)
- self.assertTrue(out.shape == (16, 32, 16, 4))
-
- out = bp.layers.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
- self.assertTrue(out.shape == (16, 32, 17, 5))
-
- out = bp.layers.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
- self.assertTrue(out.shape == (16, 32, 18, 5))
-
- out = bp.layers.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 17, 8))
-
- out = bp.layers.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 32, 5))
-
- def test_AvgPool2d_v1(self):
- arr = self.rng.rand(16, 32, 32, 8)
-
- out = bp.layers.AvgPool2d(2, 2, channel_axis=-1)(arr)
- self.assertTrue(out.shape == (16, 16, 16, 8))
-
- out = bp.layers.AvgPool2d(2, 2, channel_axis=None)(arr)
- self.assertTrue(out.shape == (16, 32, 16, 4))
-
- out = bp.layers.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
- self.assertTrue(out.shape == (16, 32, 17, 5))
-
- out = bp.layers.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
- self.assertTrue(out.shape == (16, 32, 18, 5))
-
- out = bp.layers.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 17, 8))
-
- out = bp.layers.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 32, 5))
-
- @parameterized.named_parameters(
- dict(testcase_name=f'target_size={target_size}',
- target_size=target_size)
- for target_size in [10, 9, 8, 7, 6]
- )
- def test_adaptive_pool1d(self, target_size):
- from brainpy._src.dnn.pooling import _adaptive_pool1d
-
- arr = self.rng.rand(100)
- op = jax.numpy.mean
-
- out = _adaptive_pool1d(arr, target_size, op)
- print(out.shape)
- self.assertTrue(out.shape == (target_size,))
-
- out = _adaptive_pool1d(arr, target_size, op)
- print(out.shape)
- self.assertTrue(out.shape == (target_size,))
-
- def test_AdaptiveAvgPool2d_v1(self):
- input = self.rng.randn(64, 8, 9)
-
- output = bp.layers.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
- self.assertTrue(output.shape == (64, 5, 7))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
- self.assertTrue(output.shape == (64, 2, 3))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
- self.assertTrue(output.shape == (2, 3, 9))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
- self.assertTrue(output.shape == (2, 8, 3))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input)
- self.assertTrue(output.shape == (64, 2, 3))
-
- def test_AdaptiveAvgPool2d_v2(self):
- input = self.rng.randn(128, 64, 32, 16)
-
- output = bp.layers.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
- self.assertTrue(output.shape == (128, 64, 5, 7))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
- self.assertTrue(output.shape == (128, 64, 2, 3))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
- self.assertTrue(output.shape == (128, 2, 3, 16))
-
- output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
- self.assertTrue(output.shape == (128, 64, 2, 3))
+ def test_avgpool(self):
+ x = jnp.full((1, 3, 3, 1), 2.)
+ with bm.training_environment():
+ net = bp.dnn.AvgPool((2, 2), 1, channel_axis=-1)
+ y = net(x)
+ print("out shape: ", y.shape)
+ np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
+
+ def test_MaxPool2d_v1(self):
+ arr = self.rng.rand(16, 32, 32, 8)
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr)
+ self.assertTrue(out.shape == (16, 16, 16, 8))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=None)(arr)
+ self.assertTrue(out.shape == (16, 32, 16, 4))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
+ self.assertTrue(out.shape == (16, 32, 17, 5))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
+ self.assertTrue(out.shape == (16, 32, 18, 5))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 17, 8))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 32, 5))
+
+ def test_AvgPool2d_v1(self):
+ arr = self.rng.rand(16, 32, 32, 8)
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1)(arr)
+ self.assertTrue(out.shape == (16, 16, 16, 8))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=None)(arr)
+ self.assertTrue(out.shape == (16, 32, 16, 4))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
+ self.assertTrue(out.shape == (16, 32, 17, 5))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
+ self.assertTrue(out.shape == (16, 32, 18, 5))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 17, 8))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 32, 5))
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'target_size={target_size}',
+ target_size=target_size)
+ for target_size in [10, 9, 8, 7, 6]
+ )
+ def test_adaptive_pool1d(self, target_size):
+ from brainpy._src.dnn.pooling import _adaptive_pool1d
+
+ arr = self.rng.rand(100)
+ op = jax.numpy.mean
+
+ out = _adaptive_pool1d(arr, target_size, op)
+ print(out.shape)
+ self.assertTrue(out.shape == (target_size,))
+
+ out = _adaptive_pool1d(arr, target_size, op)
+ print(out.shape)
+ self.assertTrue(out.shape == (target_size,))
+
+ def test_AdaptiveAvgPool2d_v1(self):
+ input = self.rng.randn(64, 8, 9)
+
+ output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
+ self.assertTrue(output.shape == (64, 5, 7))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
+ self.assertTrue(output.shape == (64, 2, 3))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
+ self.assertTrue(output.shape == (2, 3, 9))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
+ self.assertTrue(output.shape == (2, 8, 3))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input)
+ self.assertTrue(output.shape == (64, 2, 3))
+
+ def test_AdaptiveAvgPool2d_v2(self):
+ input = self.rng.randn(128, 64, 32, 16)
+
+ output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
+ self.assertTrue(output.shape == (128, 64, 5, 7))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
+ self.assertTrue(output.shape == (128, 64, 2, 3))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
+ self.assertTrue(output.shape == (128, 2, 3, 16))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
+ self.assertTrue(output.shape == (128, 64, 2, 3))
+ print()
+
+ def test_AdaptiveAvgPool3d_v1(self):
+ input = bm.random.randn(10, 128, 64, 32)
+ net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], channel_axis=0, mode=bm.nonbatching_mode)
+ output = net(input)
+ self.assertTrue(output.shape == (10, 6, 5, 3))
+
+ def test_AdaptiveAvgPool3d_v2(self):
+ input = bm.random.randn(10, 20, 128, 64, 32)
+ net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], mode=bm.batching_mode)
+ output = net(input)
+ self.assertTrue(output.shape == (10, 6, 5, 3, 32))
+
+ @parameterized.product(
+ axis=(-1, 0, 1)
+ )
+ def test_AdaptiveMaxPool1d_v1(self, axis):
+ input = bm.random.randn(32, 16)
+ net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
+ output = net(input)
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2)
+ )
+ def test_AdaptiveMaxPool1d_v2(self, axis):
+ input = bm.random.randn(2, 32, 16)
+ net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
+ output = net(input)
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2)
+ )
+ def test_AdaptiveMaxPool2d_v1(self, axis):
+ input = bm.random.randn(32, 16, 12)
+    net = bp.dnn.AdaptiveMaxPool2d(target_shape=[5, 4], channel_axis=axis)
+ output = net(input)
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2, 3)
+ )
+ def test_AdaptiveMaxPool2d_v2(self, axis):
+ input = bm.random.randn(2, 32, 16, 12)
+    net = bp.dnn.AdaptiveMaxPool2d(target_shape=[5, 4], channel_axis=axis)
+ # output = net(input)
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2, 3)
+ )
+ def test_AdaptiveMaxPool3d_v1(self, axis):
+ input = bm.random.randn(2, 128, 64, 32)
+ net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
+ output = net(input)
+ print()
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2, 3, 4)
+ )
+  def test_AdaptiveMaxPool3d_v2(self, axis):
+ input = bm.random.randn(2, 128, 64, 32, 16)
+ net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
+ output = net(input)
+
+
+if __name__ == '__main__':
+ absltest.main()
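
The `test_adaptive_pool1d` cases above only assert the output shape. The usual way adaptive pooling hits an exact `target_size` (assumed here; the `_adaptive_pool1d` internals are not shown in this patch) is to split the input axis into `target_size` nearly equal windows and reduce each one with the given op:

import numpy as np

def adaptive_pool1d(x, target_size, op=np.mean):
  # window i covers [floor(i*n/t), ceil((i+1)*n/t)) of the input axis
  n = len(x)
  idx = np.arange(target_size)
  starts = (idx * n) // target_size
  ends = -(-(idx + 1) * n // target_size)   # ceiling division
  return np.array([op(x[s:e]) for s, e in zip(starts, ends)])

x = np.random.rand(100)
for t in (10, 9, 8, 7, 6):
  assert adaptive_pool1d(x, t).shape == (t,)
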
From 8d94479440afa719aecc7b4ae6887412650f45cd Mon Sep 17 00:00:00 2001
From: GYF <1337838189@qq.com>
Date: Tue, 11 Jul 2023 16:55:28 +0800
Subject: [PATCH 026/326] fix conflicts
---
brainpy/_src/dnn/conv.py | 2 +-
brainpy/_src/dnn/tests/test_conv_layers.py | 6 +++---
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index f045648c4..8bcd8d720 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -711,7 +711,7 @@ def __init__(
name: The name of the module.
"""
super().__init__(
- num_spatial_dims=1,
+ num_spatial_dims=3,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py
index 71e63682f..828b06496 100644
--- a/brainpy/_src/dnn/tests/test_conv_layers.py
+++ b/brainpy/_src/dnn/tests/test_conv_layers.py
@@ -199,9 +199,9 @@ def test_conv_transpose(self):
in_channels=3,
out_channels=4,
kernel_size=(3, 3, 3),
- # padding='VALID',
- # w_initializer=bp.init.OneInit(),
- # b_initializer=bp.init.OneInit() if use_bias else None,
+ padding='VALID',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
mode=bm.training_mode
)
y = conv_transpose_module(x)
From dcab68ba7253dc75c846df40a5d137a797a9196d Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 11 Jul 2023 21:29:56 +0800
Subject: [PATCH 027/326] standardize ion channels modeling
---
brainpy/_src/dyn/channels/__init__.py | 26 +-
brainpy/_src/dyn/channels/base.py | 55 +-
.../_src/dyn/channels/{Ca.py => calcium.py} | 197 +-
.../{IH.py => hyperpolarization_activated.py} | 30 +-
brainpy/_src/dyn/channels/leaky.py | 52 +-
brainpy/_src/dyn/channels/potassium.py | 2038 +++++++++++++++++
.../_src/dyn/channels/potassium_calcium.py | 128 ++
...KCa.py => potassium_calcium_compatible.py} | 9 +-
.../{K.py => potassium_compatible.py} | 111 +-
brainpy/_src/dyn/channels/sodium.py | 381 +++
.../channels/{Na.py => sodium_compatible.py} | 18 +-
brainpy/_src/dyn/ions/__init__.py | 4 +-
brainpy/_src/dyn/ions/base.py | 202 +-
brainpy/_src/dyn/ions/{ca.py => calcium.py} | 86 +-
brainpy/_src/dyn/ions/potassium.py | 52 +
brainpy/_src/dyn/ions/sodium.py | 52 +
brainpy/_src/dyn/ions/tests/test_MixIons.py | 98 +
brainpy/dyn/channels.py | 43 +-
brainpy/dyn/ions.py | 20 +-
brainpy/dyn/neurons.py | 1 +
20 files changed, 3244 insertions(+), 359 deletions(-)
rename brainpy/_src/dyn/channels/{Ca.py => calcium.py} (85%)
rename brainpy/_src/dyn/channels/{IH.py => hyperpolarization_activated.py} (94%)
create mode 100644 brainpy/_src/dyn/channels/potassium.py
create mode 100644 brainpy/_src/dyn/channels/potassium_calcium.py
rename brainpy/_src/dyn/channels/{KCa.py => potassium_calcium_compatible.py} (96%)
rename brainpy/_src/dyn/channels/{K.py => potassium_compatible.py} (93%)
create mode 100644 brainpy/_src/dyn/channels/sodium.py
rename brainpy/_src/dyn/channels/{Na.py => sodium_compatible.py} (96%)
rename brainpy/_src/dyn/ions/{ca.py => calcium.py} (84%)
create mode 100644 brainpy/_src/dyn/ions/potassium.py
create mode 100644 brainpy/_src/dyn/ions/sodium.py
create mode 100644 brainpy/_src/dyn/ions/tests/test_MixIons.py
diff --git a/brainpy/_src/dyn/channels/__init__.py b/brainpy/_src/dyn/channels/__init__.py
index 326e68b12..4d43a4d2a 100644
--- a/brainpy/_src/dyn/channels/__init__.py
+++ b/brainpy/_src/dyn/channels/__init__.py
@@ -1,25 +1,9 @@
# -*- coding: utf-8 -*-
-"""
-
-Access through ``brainpy.channels``.
-"""
-
-from . import base, Ca, IH, K, Na, KCa, leaky
-
-__all__ = []
-__all__ += base.__all__
-__all__ += K.__all__
-__all__ += Na.__all__
-__all__ += Ca.__all__
-__all__ += IH.__all__
-__all__ += KCa.__all__
-__all__ += leaky.__all__
-
from .base import *
-from .K import *
-from .Na import *
-from .IH import *
-from .Ca import *
-from .KCa import *
+from .potassium import *
+from .sodium import *
+from .hyperpolarization_activated import *
+from .calcium import *
+from .potassium_calcium import *
from .leaky import *
diff --git a/brainpy/_src/dyn/channels/base.py b/brainpy/_src/dyn/channels/base.py
index 863bbd7d4..b933930a0 100644
--- a/brainpy/_src/dyn/channels/base.py
+++ b/brainpy/_src/dyn/channels/base.py
@@ -2,11 +2,10 @@
from brainpy._src.dyn.base import IonChaDyn
from brainpy._src.mixin import TreeNode
-from brainpy._src.dyn.ions.base import Calcium
from brainpy._src.dyn.neurons.hh import HHTypedNeuron
__all__ = [
- 'IonChannel', 'IhChannel', 'CalciumChannel', 'SodiumChannel', 'PotassiumChannel', 'LeakyChannel',
+ 'IonChannel',
]
@@ -16,16 +15,13 @@ class IonChannel(IonChaDyn, TreeNode):
'''The type of the master object.'''
master_type = HHTypedNeuron
- def update(self, V):
+ def update(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by the subclass.')
- def current(self, V):
+ def current(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by the subclass.')
- def reset(self, V, batch_size=None):
- self.reset_state(V, batch_size)
-
- def reset_state(self, V, batch_size=None):
+ def reset_state(self, *args, **kwargs):
raise NotImplementedError('Must be implemented by the subclass.')
def clear_input(self):
@@ -33,46 +29,3 @@ def clear_input(self):
def __repr__(self):
return f'{self.name}(size={self.size})'
-
-
-class CalciumChannel(IonChannel):
- """Base class for Calcium ion channels."""
-
- master_type = Calcium
- '''The type of the master object.'''
-
- def update(self, V, C_Ca, E_Ca):
- raise NotImplementedError
-
- def current(self, V, C_Ca, E_Ca):
- raise NotImplementedError
-
- def reset(self, V, C_Ca, E_Ca, batch_size: int = None):
- self.reset_state(V, C_Ca, E_Ca, batch_size)
-
- def reset_state(self, V, C_Ca, E_Ca, batch_size: int = None):
- raise NotImplementedError('Must be implemented by the subclass.')
-
-
-class IhChannel(IonChannel):
- """Base class for Ih channel models."""
- master_type = HHTypedNeuron
-
-
-class PotassiumChannel(IonChannel):
- """Base class for potassium channel dynamics."""
-
- '''The type of the master object.'''
- master_type = HHTypedNeuron
-
-
-class LeakyChannel(IonChannel):
- """Base class for leaky channel dynamics."""
-
- master_type = HHTypedNeuron
-
-
-class SodiumChannel(IonChannel):
- """Base class for sodium channel dynamics."""
-
- master_type = HHTypedNeuron
diff --git a/brainpy/_src/dyn/channels/Ca.py b/brainpy/_src/dyn/channels/calcium.py
similarity index 85%
rename from brainpy/_src/dyn/channels/Ca.py
rename to brainpy/_src/dyn/channels/calcium.py
index 91c532910..3d8a04ef9 100644
--- a/brainpy/_src/dyn/channels/Ca.py
+++ b/brainpy/_src/dyn/channels/calcium.py
@@ -5,30 +5,45 @@
"""
-from typing import Union, Callable
+from typing import Union, Callable, Optional
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dyn.ions.ca import CalciumDyna
+from brainpy._src.dyn.ions.calcium import Calcium, CalciumDyna
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import Shape, ArrayType
-from .base import CalciumChannel
+from .base import IonChannel
__all__ = [
- 'ICaN_IS2008',
+ 'CalciumChannel',
+ 'ICaN_IS2008',
'ICaT_HM1992',
'ICaT_HP1992',
-
'ICaHT_HM1992',
-
'ICaL_IS2008',
]
-# -------------------------
+class CalciumChannel(IonChannel):
+ """Base class for Calcium ion channels."""
+
+ master_type = Calcium
+ '''The type of the master object.'''
+
+ def update(self, V, C, E):
+ raise NotImplementedError
+
+ def current(self, V, C, E):
+ raise NotImplementedError
+
+ def reset(self, V, C, E, batch_size: int = None):
+ self.reset_state(V, C, E, batch_size)
+
+ def reset_state(self, V, C, E, batch_size: int = None):
+ raise NotImplementedError('Must be implemented by the subclass.')
class _ICa_p2q_ss(CalciumChannel):
@@ -72,13 +87,13 @@ def __init__(
phi_q: Union[float, ArrayType, Initializer, Callable] = 3.,
g_max: Union[float, ArrayType, Initializer, Callable] = 2.,
method: str = 'exp_auto',
- mode: bm.Mode = None,
- name: str = None
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None
):
- super(_ICa_p2q_ss, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode, )
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode, )
# parameters
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
@@ -98,13 +113,13 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
- def update(self, V, C_Ca, E_Ca):
+ def update(self, V, C, E):
self.p.value, self.q.value = self.integral(self.p, self.q, share['t'], V, share['dt'])
- def current(self, V, C_Ca, E_Ca):
- return self.g_max * self.p * self.p * self.q * (E_Ca - V)
+ def current(self, V, C, E):
+ return self.g_max * self.p * self.p * self.q * (E - V)
- def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
+ def reset_state(self, V, C, E, batch_size=None):
self.p.value = self.f_p_inf(V)
self.q.value = self.f_q_inf(V)
if batch_size is not None:
@@ -165,13 +180,13 @@ def __init__(
phi_q: Union[float, ArrayType, Initializer, Callable] = 3.,
g_max: Union[float, ArrayType, Initializer, Callable] = 2.,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
- super(_ICa_p2q_markov, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
@@ -191,13 +206,13 @@ def dp(self, p, t, V):
def dq(self, q, t, V):
return self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q)
- def update(self, V, C_Ca, E_Ca):
+ def update(self, V, C, E):
self.p.value, self.q.value = self.integral(self.p, self.q, share['t'], V, share['dt'])
- def current(self, V, C_Ca, E_Ca):
- return self.g_max * self.p * self.p * self.q * (E_Ca - V)
+ def current(self, V, C, E):
+ return self.g_max * self.p * self.p * self.q * (E - V)
- def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
+ def reset_state(self, V, C, E, batch_size=None):
alpha, beta = self.f_p_alpha(V), self.f_p_beta(V)
self.p.value = alpha / (alpha + beta)
alpha, beta = self.f_q_alpha(V), self.f_q_beta(V)
@@ -267,13 +282,13 @@ def __init__(
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
phi: Union[float, ArrayType, Initializer, Callable] = 1.,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
- super(ICaN_IS2008, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.E = parameter(E, self.varshape, allow_none=False)
@@ -291,15 +306,15 @@ def derivative(self, p, t, V):
p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6
return self.phi * (phi_p - p) / p_inf
- def update(self, V, C_Ca, E_Ca):
+ def update(self, V, C, E):
self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
- def current(self, V, C_Ca, E_Ca):
- M = C_Ca / (C_Ca + 0.2)
+ def current(self, V, C, E):
+ M = C / (C + 0.2)
g = self.g_max * M * self.p
return g * (self.E - V)
- def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
+ def reset_state(self, V, C, E, batch_size=None):
self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2))
if batch_size is not None:
assert self.p.shape[0] == batch_size
@@ -365,19 +380,19 @@ def __init__(
phi_p: Union[float, ArrayType, Initializer, Callable] = None,
phi_q: Union[float, ArrayType, Initializer, Callable] = None,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p
phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q
- super(ICaT_HM1992, self).__init__(size,
- keep_size=keep_size,
- name=name,
- method=method,
- g_max=g_max,
- phi_p=phi_p,
- phi_q=phi_q,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ mode=mode)
# parameters
self.T = parameter(T, self.varshape, allow_none=False)
@@ -397,8 +412,8 @@ def f_q_inf(self, V):
def f_q_tau(self, V):
return bm.where(V >= (-80. + self.V_sh),
- bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
- bm.exp((V + 467. - self.V_sh) / 66.6))
+ bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
+ bm.exp((V + 467. - self.V_sh) / 66.6))
class ICaT_HP1992(_ICa_p2q_ss):
@@ -463,19 +478,19 @@ def __init__(
phi_p: Union[float, ArrayType, Initializer, Callable] = None,
phi_q: Union[float, ArrayType, Initializer, Callable] = None,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p
phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q
- super(ICaT_HP1992, self).__init__(size,
- keep_size=keep_size,
- name=name,
- method=method,
- g_max=g_max,
- phi_p=phi_p,
- phi_q=phi_q,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ mode=mode)
# parameters
self.T = parameter(T, self.varshape, allow_none=False)
@@ -556,17 +571,17 @@ def __init__(
g_max: Union[float, ArrayType, Initializer, Callable] = 2.,
V_sh: Union[float, ArrayType, Initializer, Callable] = 25.,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
- super(ICaHT_HM1992, self).__init__(size,
- keep_size=keep_size,
- name=name,
- method=method,
- g_max=g_max,
- phi_p=T_base_p ** ((T - 24) / 10),
- phi_q=T_base_q ** ((T - 24) / 10),
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_p=T_base_p ** ((T - 24) / 10),
+ phi_q=T_base_q ** ((T - 24) / 10),
+ mode=mode)
# parameters
self.T = parameter(T, self.varshape, allow_none=False)
@@ -593,8 +608,8 @@ def f_q_inf(self, V):
def f_q_tau(self, V):
return bm.where(V >= (-80. + self.V_sh),
- bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
- bm.exp((V + 467. - self.V_sh) / 66.6))
+ bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
+ bm.exp((V + 467. - self.V_sh) / 66.6))
class ICaHT_Re1993(_ICa_p2q_markov):
@@ -663,19 +678,19 @@ def __init__(
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p
phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q
- super(ICaHT_Re1993, self).__init__(size,
- keep_size=keep_size,
- name=name,
- method=method,
- g_max=g_max,
- phi_p=phi_p,
- phi_q=phi_q,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ mode=mode)
self.T = parameter(T, self.varshape, allow_none=False)
self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False)
self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False)
@@ -750,17 +765,17 @@ def __init__(
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
- super(ICaL_IS2008, self).__init__(size,
- keep_size=keep_size,
- name=name,
- method=method,
- g_max=g_max,
- phi_p=T_base_p ** ((T - 24) / 10),
- phi_q=T_base_q ** ((T - 24) / 10),
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_p=T_base_p ** ((T - 24) / 10),
+ phi_q=T_base_q ** ((T - 24) / 10),
+ mode=mode)
# parameters
self.T = parameter(T, self.varshape, allow_none=False)
diff --git a/brainpy/_src/dyn/channels/IH.py b/brainpy/_src/dyn/channels/hyperpolarization_activated.py
similarity index 94%
rename from brainpy/_src/dyn/channels/IH.py
rename to brainpy/_src/dyn/channels/hyperpolarization_activated.py
index 708723a3b..89c75eea1 100644
--- a/brainpy/_src/dyn/channels/IH.py
+++ b/brainpy/_src/dyn/channels/hyperpolarization_activated.py
@@ -2,18 +2,18 @@
"""
This module implements hyperpolarization-activated cation channels.
-
"""
-from typing import Union, Callable
+from typing import Union, Callable, Optional
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dyn.ions.base import Calcium
+from brainpy._src.dyn.ions.calcium import Calcium
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import Shape, ArrayType
-from .base import IhChannel, CalciumChannel
+from .base import IonChannel
__all__ = [
'Ih_HM1992',
@@ -21,7 +21,7 @@
]
-class Ih_HM1992(IhChannel):
+class Ih_HM1992(IonChannel):
r"""The hyperpolarization-activated cation current model propsoed by (Huguenard & McCormick, 1992) [1]_.
The hyperpolarization-activated cation current model is adopted from
@@ -55,6 +55,8 @@ class Ih_HM1992(IhChannel):
"""
+ master_type = HHTypedNeuron
+
def __init__(
self,
size: Shape,
@@ -63,13 +65,13 @@ def __init__(
E: Union[float, ArrayType, Initializer, Callable] = 43.,
phi: Union[float, ArrayType, Initializer, Callable] = 1.,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
- super(Ih_HM1992, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.phi = parameter(phi, self.varshape, allow_none=False)
@@ -103,7 +105,7 @@ def f_p_tau(self, V):
return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87))
-class Ih_De1996(IhChannel, CalciumChannel):
+class Ih_De1996(IonChannel):
r"""The hyperpolarization-activated cation current model propsoed by (Destexhe, et al., 1996) [1]_.
The full kinetic schema was
@@ -173,8 +175,8 @@ def __init__(
T_base: Union[float, ArrayType] = 3.,
phi: Union[float, ArrayType, Initializer, Callable] = None,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
):
super().__init__(size,
keep_size=keep_size,
diff --git a/brainpy/_src/dyn/channels/leaky.py b/brainpy/_src/dyn/channels/leaky.py
index 5a6f1b5e1..981152534 100644
--- a/brainpy/_src/dyn/channels/leaky.py
+++ b/brainpy/_src/dyn/channels/leaky.py
@@ -5,20 +5,29 @@
"""
-from typing import Union, Callable
+from typing import Union, Callable, Sequence
import brainpy.math as bm
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
from brainpy._src.initialize import Initializer, parameter
-from brainpy.types import ArrayType, Shape
-
-from .base import LeakyChannel
+from brainpy.types import ArrayType
+from .base import IonChannel
__all__ = [
+ 'LeakyChannel',
'IL',
- 'IKL',
]
+class LeakyChannel(IonChannel):
+ """Base class for leaky channel dynamics."""
+
+ master_type = HHTypedNeuron
+
+ def reset_state(self, V, batch_size=None):
+ pass
+
+
class IL(LeakyChannel):
"""The leakage channel current.
@@ -32,7 +41,7 @@ class IL(LeakyChannel):
def __init__(
self,
- size,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.1,
E: Union[int, float, ArrayType, Initializer, Callable] = -70.,
@@ -57,34 +66,3 @@ def update(self, V):
def current(self, V):
return self.g_max * (self.E - V)
-
-
-class IKL(IL):
- """The potassium leak channel current.
-
- Parameters
- ----------
- g_max : float
- The potassium leakage conductance which is modulated by both
- acetylcholine and norepinephrine.
- E : float
- The reversal potential.
- """
-
- def __init__(
- self,
- size: Shape,
- keep_size: bool = False,
- g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005,
- E: Union[int, float, ArrayType, Initializer, Callable] = -90.,
- method: str = None,
- name: str = None,
- mode: bm.Mode = None,
- ):
- super(IKL, self).__init__(size=size,
- keep_size=keep_size,
- g_max=g_max,
- E=E,
- method=method,
- name=name,
- mode=mode)
diff --git a/brainpy/_src/dyn/channels/potassium.py b/brainpy/_src/dyn/channels/potassium.py
new file mode 100644
index 000000000..5ea82d859
--- /dev/null
+++ b/brainpy/_src/dyn/channels/potassium.py
@@ -0,0 +1,2038 @@
+# -*- coding: utf-8 -*-
+
+"""
+This module implements voltage-dependent potassium channels.
+
+"""
+
+from typing import Union, Callable, Optional, Sequence
+
+import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.dyn.ions.potassium import Potassium
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
+from brainpy._src.initialize import Initializer, parameter, variable
+from brainpy._src.integrators import odeint, JointEq
+from brainpy.types import ArrayType
+from .base import IonChannel
+
+__all__ = [
+ 'PotassiumChannel',
+ 'IKDR_Ba2002v2',
+ 'IK_TM1991v2',
+ 'IK_HH1952v2',
+ 'IKA1_HM1992v2',
+ 'IKA2_HM1992v2',
+ 'IKK2A_HM1992v2',
+ 'IKK2B_HM1992v2',
+ 'IKNI_Ya1989v2',
+ 'IK_Leak',
+]
+
+
+class PotassiumChannel(IonChannel):
+ """Base class for sodium channel dynamics."""
+
+ master_type = Potassium
+
+ def update(self, V, C, E):
+ raise NotImplementedError
+
+ def current(self, V, C, E):
+ raise NotImplementedError
+
+ def reset(self, V, C, E, batch_size: int = None):
+ self.reset_state(V, C, E, batch_size)
+
+ def reset_state(self, V, C, E, batch_size: int = None):
+ raise NotImplementedError('Must be implemented by the subclass.')
+
+
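
The `PotassiumChannel` base class above fixes the calling convention introduced by this refactor: `update`, `current` and `reset_state` receive the membrane potential `V` together with the potassium concentration `C` and reversal potential `E`, which the master `Potassium` ion is expected to supply. A schematic subclass (hypothetical, not part of this patch) to illustrate that contract:

from brainpy._src.dyn.channels.potassium import PotassiumChannel  # module added by this patch


class IKOhmic(PotassiumChannel):
  """Hypothetical instantaneous (ohmic) K+ current, for illustration only."""

  def __init__(self, size, g_max=1., **kwargs):
    super().__init__(size, **kwargs)
    self.g_max = g_max

  def update(self, V, C, E):
    pass  # no gating variable to integrate

  def current(self, V, C, E):
    return self.g_max * (E - V)  # driven by the reversal potential supplied by the K+ ion

  def reset_state(self, V, C, E, batch_size=None):
    pass

In a complete model such a channel would sit under a `Potassium` ion inside an HH-typed neuron, which is what provides `C` and `E` at run time.
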
+class _IK_p4_markov_v2(PotassiumChannel):
+ r"""The delayed rectifier potassium channel of :math:`p^4`
+  current which is described with a first-order Markov chain.
+
+ This general potassium current model should have the form of
+
+  .. math::
+
+    \begin{aligned}
+    I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 * (E - V) \\
+    \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p)
+    \end{aligned}
+
+ where :math:`\phi` is a temperature-dependent factor.
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The object size.
+ keep_size: bool
+ Whether we use `size` to initialize the variable. Otherwise, variable shape
+ will be initialized as `num`.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ phi : float, ArrayType, Initializer, Callable
+ The temperature-dependent factor.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.phi = parameter(phi, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(self.derivative, method=method)
+
+ def derivative(self, p, t, V):
+ return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p)
+
+ def update(self, V, C, E):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
+
+ def current(self, V, C, E):
+ return self.g_max * self.p ** 4 * (E - V)
+
+ def reset_state(self, V, C, E, batch_size=None):
+ alpha = self.f_p_alpha(V)
+ beta = self.f_p_beta(V)
+ self.p.value = alpha / (alpha + beta)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+
+ def f_p_alpha(self, V):
+ raise NotImplementedError
+
+ def f_p_beta(self, V):
+ raise NotImplementedError
+
+
+class IKDR_Ba2002v2(_IK_p4_markov_v2):
+ r"""The delayed rectifier potassium channel current.
+
+  The potassium current model is adopted from (Bazhenov et al., 2002) [1]_.
+  Its dynamics is given by:
+
+  .. math::
+
+    \begin{aligned}
+    I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 * (E - V) \\
+    \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
+    \alpha_{p} &= \frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\
+    \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right)
+    \end{aligned}
+
+ where :math:`\phi` is a temperature-dependent factor, which is given by
+ :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The object size.
+ keep_size: bool
+ Whether we use `size` to initialize the variable. Otherwise, variable shape
+ will be initialized as `num`.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ T_base : float, ArrayType
+    The base of the temperature factor.
+ T : float, ArrayType, Initializer, Callable
+ The temperature (Celsius, :math:`^{\circ}C`).
+ V_sh : float, ArrayType, Initializer, Callable
+ The shift of the membrane potential to spike.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ References
+ ----------
+ .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
+ and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.
+
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = -50.,
+ T_base: Union[float, ArrayType] = 3.,
+ T: Union[float, ArrayType] = 36.,
+ phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ phi = T_base ** ((T - 36) / 10) if phi is None else phi
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi=phi,
+ mode=mode)
+
+ # parameters
+ self.T = parameter(T, self.varshape, allow_none=False)
+ self.T_base = parameter(T_base, self.varshape, allow_none=False)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ tmp = V - self.V_sh - 15.
+ return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))
+
+ def f_p_beta(self, V):
+ return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
+
+
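
As a concrete reading of the `IKDR_Ba2002v2` rate equations above (with the default `V_sh = -50` mV and `T = 36` °C, so `phi = 3 ** ((36 - 36) / 10) = 1`), the steady state used by `reset_state` is `p_inf = alpha_p / (alpha_p + beta_p)` and can be evaluated directly. A stand-alone NumPy check mirroring `f_p_alpha` / `f_p_beta`:

import numpy as np

V_sh = -50.

def alpha_p(V):
  tmp = V - V_sh - 15.
  return 0.032 * tmp / (1. - np.exp(-tmp / 5.))

def beta_p(V):
  return 0.5 * np.exp(-(V - V_sh - 10.) / 40.)

for V in (-70., -50., -30.):
  a, b = alpha_p(V), beta_p(V)
  print(V, a / (a + b))   # p_inf grows toward 1 as the membrane depolarizes
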
+class IK_TM1991v2(_IK_p4_markov_v2):
+ r"""The potassium channel described by (Traub and Miles, 1991) [1]_.
+
+ The dynamics of this channel is given by:
+
+  .. math::
+
+    \begin{aligned}
+    I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 * (E - V) \\
+    \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
+    \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\
+    \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40)
+    \end{aligned}
+
+ where :math:`V_{sh}` is the membrane shift (default -63 mV), and
+ :math:`\phi` is the temperature-dependent factor (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ References
+ ----------
+ .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus.
+ Vol. 777. Cambridge University Press, 1991.
+
+ See Also
+ --------
+ INa_TM1991
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=phi,
+ g_max=g_max,
+ mode=mode)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ c = 15 - V + self.V_sh
+ return 0.032 * c / (bm.exp(c / 5) - 1.)
+
+ def f_p_beta(self, V):
+ return 0.5 * bm.exp((10 - V + self.V_sh) / 40)
+
+
+class IK_HH1952v2(_IK_p4_markov_v2):
+ r"""The potassium channel described by Hodgkin–Huxley model [1]_.
+
+ The dynamics of this channel is given by:
+
+  .. math::
+
+    \begin{aligned}
+    I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 * (E - V) \\
+    \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
+    \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\
+    \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right)
+    \end{aligned}
+
+ where :math:`V_{sh}` is the membrane shift (default -45 mV), and
+ :math:`\phi` is the temperature-dependent factor (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ References
+ ----------
+ .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of
+ membrane current and its application to conduction and excitation in
+ nerve." The Journal of physiology 117.4 (1952): 500.
+
+ See Also
+ --------
+ INa_HH1952
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=phi,
+ g_max=g_max,
+ mode=mode)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ temp = V - self.V_sh + 10
+ return 0.01 * temp / (1 - bm.exp(-temp / 10))
+
+ def f_p_beta(self, V):
+ return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
+
+
+class _IKA_p4q_ss_v2(PotassiumChannel):
+ r"""The rapidly inactivating Potassium channel of :math:`p^4q`
+  current which is described with a steady-state formulation.
+
+ This model is developed according to the average behavior of
+  rapidly inactivating potassium channels in thalamic relay neurons [2]_ [3]_.
+
+ .. math::
+
+ &IA = g_{\mathrm{max}} p^4 q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
+ self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+ self.q = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(JointEq(self.dp, self.dq), method=method)
+
+ def dp(self, p, t, V):
+ return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
+
+ def dq(self, q, t, V):
+ return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
+
+ def update(self, V, C, E):
+ self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt'])
+
+ def current(self, V, C, E):
+ return self.g_max * self.p ** 4 * self.q * (E - V)
+
+ def reset_state(self, V, C, E, batch_size=None):
+ self.p.value = self.f_p_inf(V)
+ self.q.value = self.f_q_inf(V)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+ assert self.q.shape[0] == batch_size
+
+ def f_p_inf(self, V):
+ raise NotImplementedError
+
+ def f_p_tau(self, V):
+ raise NotImplementedError
+
+ def f_q_inf(self, V):
+ raise NotImplementedError
+
+ def f_q_tau(self, V):
+ raise NotImplementedError
+
+
+class IKA1_HM1992v2(_IKA_p4q_ss_v2):
+ r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ This model is developed according to the average behavior of the
+ rapidly inactivating potassium channels in thalamic relay neurons [2]_ [1]_.
+
+ .. math::
+
+ &IA = g_{\mathrm{max}} p^4 q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\
+ &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\
+ \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array}
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ See Also
+ --------
+ IKA2_HM1992
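+
+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference); it assumes
+ the channel can be constructed standalone, as its signature suggests:
+
+ >>> import brainpy.math as bm
+ >>> ch = IKA1_HM1992v2(size=1)
+ >>> V = bm.asarray([-90., -70., -50.])
+ >>> p_inf = ch.f_p_inf(V)   # steady-state activation
+ >>> tau_q = ch.f_q_tau(V)   # inactivation time constant; 19 ms above V_sh - 63 mV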
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 30.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
+ bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))
+
+ def f_q_tau(self, V):
+ return bm.where(V < -63 + self.V_sh,
+ 1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
+ bm.exp(-(V - self.V_sh + 238.) / 37.5)),
+ 19.)
+
+
+class IKA2_HM1992v2(_IKA_p4q_ss_v2):
+ r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ This model is developed according to the average behavior of the
+ rapidly inactivating potassium channels in thalamic relay neurons [2]_ [1]_.
+
+ .. math::
+
+ &IA = g_{\mathrm{max}} p^4 q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\
+ &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\
+ \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array}
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ See Also
+ --------
+ IKA1_HM1992
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 20.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi_q=phi_q,
+ phi_p=phi_p,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
+ bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))
+
+ def f_q_tau(self, V):
+ return bm.where(V < -63 + self.V_sh,
+ 1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
+ bm.exp(-(V - self.V_sh + 238.) / 37.5)),
+ 19.)
+
+
+class _IKK2_pq_ss_v2(PotassiumChannel):
+ r"""The slowly inactivating Potassium channel of :math:`pq`
+ current which described with steady-state format.
+
+ The dynamics of the model is given as [2]_ [3]_.
+
+ .. math::
+
+ &IK2 = g_{\mathrm{max}} p q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
+ self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+ self.q = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(JointEq(self.dp, self.dq), method=method)
+
+ def dp(self, p, t, V):
+ return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
+
+ def dq(self, q, t, V):
+ return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
+
+ def update(self, V, C, E):
+ self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt'])
+
+ def current(self, V, C, E):
+ return self.g_max * self.p * self.q * (E - V)
+
+ def reset_state(self, V, C, E, batch_size=None):
+ self.p.value = self.f_p_inf(V)
+ self.q.value = self.f_q_inf(V)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+ assert self.q.shape[0] == batch_size
+
+ def f_p_inf(self, V):
+ raise NotImplementedError
+
+ def f_p_tau(self, V):
+ raise NotImplementedError
+
+ def f_q_inf(self, V):
+ raise NotImplementedError
+
+ def f_q_tau(self, V):
+ raise NotImplementedError
+
+
+class IKK2A_HM1992v2(_IKK2_pq_ss_v2):
+ r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ The dynamics of the model are given as follows [2]_ [3]_:
+
+ .. math::
+
+ &IK2 = g_{\mathrm{max}} p q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+
+ \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\
+ & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
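+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference), showing how
+ the steady-state activation and inactivation curves can be inspected:
+
+ >>> import brainpy.math as bm
+ >>> ch = IKK2A_HM1992v2(size=1)
+ >>> V = bm.arange(-100., -20., 10.)
+ >>> overlap = ch.f_p_inf(V) * ch.f_q_inf(V)  # window where both gates are open
+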
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ g_max=g_max,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
+ bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))
+
+ def f_q_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+ bm.exp(-(V - self.V_sh + 130.) / 7.1))
+
+
+class IKK2B_HM1992v2(_IKK2_pq_ss_v2):
+ r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ The dynamics of the model are given as follows [2]_ [3]_:
+
+ .. math::
+
+ &IK2 = g_{\mathrm{max}} p q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+
+ \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\
+ &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) +
+ \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\
+ \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array}
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ g_max=g_max,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
+ bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))
+
+ def f_q_tau(self, V):
+ return bm.where(V < -70 + self.V_sh,
+ 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+ bm.exp(-(V - self.V_sh + 130.) / 7.1)),
+ 8.9)
+
+
+class IKNI_Ya1989v2(PotassiumChannel):
+ r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_.
+
+ This slow potassium current can effectively account for spike-frequency adaptation.
+
+ .. math::
+
+ \begin{aligned}
+ &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\
+ &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\
+ &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\
+ &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]}
+ \end{aligned}
+
+ where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and
+ :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise.
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ tau_max: float, ArrayType, Callable, Initializer
+ The :math:`\tau_{\max}` parameter.
+
+ References
+ ----------
+ .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133.
+
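+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference):
+
+ >>> import brainpy.math as bm
+ >>> ch = IKNI_Ya1989v2(size=1, tau_max=4e3)
+ >>> V = bm.asarray([-60., -35., -10.])
+ >>> p_inf = ch.f_p_inf(V)   # slow activation gate
+ >>> tau_p = ch.f_p_tau(V)   # voltage-dependent time constant (ms)
+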
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 0.004,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.tau_max = parameter(tau_max, self.varshape, allow_none=False)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+ self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
+ self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(self.dp, method=method)
+
+ def dp(self, p, t, V):
+ return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
+
+ def update(self, V, C, E):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
+
+ def current(self, V, C, E):
+ return self.g_max * self.p * (E - V)
+
+ def reset_state(self, V, C, E, batch_size=None):
+ self.p.value = self.f_p_inf(V)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.))
+
+ def f_p_tau(self, V):
+ temp = V - self.V_sh + 35.
+ return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
+
+
+class _IK_p4_markov(PotassiumChannel):
+ r"""The delayed rectifier potassium channel of :math:`p^4`
+ current which described with first-order Markov chain.
+
+ This general potassium current model should have the form of
+
+ .. math::
+
+ \begin{aligned}
+ I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\
+ \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p)
+ \end{aligned}
+
+ where :math:`\phi` is a temperature-dependent factor.
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The object size.
+ keep_size: bool
+ Whether to keep `size` as the shape of the variables. Otherwise, the
+ variables are flattened to the shape `(num,)`.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ phi : float, ArrayType, Initializer, Callable
+ The temperature-dependent factor.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
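+ Examples
+ --------
+ Concrete channels only need to supply the two rate functions; a minimal
+ (hypothetical) subclass could look like:
+
+ >>> import brainpy.math as bm
+ >>> class MyIK(_IK_p4_markov):
+ ...     def f_p_alpha(self, V):
+ ...         return 0.01 * (V + 55.) / (1. - bm.exp(-(V + 55.) / 10.))
+ ...     def f_p_beta(self, V):
+ ...         return 0.125 * bm.exp(-(V + 65.) / 80.)
+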
+ """
+ master_type = HHTypedNeuron
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.phi = parameter(phi, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(self.derivative, method=method)
+
+ def derivative(self, p, t, V):
+ return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p)
+
+ def update(self, V):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
+
+ def current(self, V):
+ return self.g_max * self.p ** 4 * (self.E - V)
+
+ def reset_state(self, V, batch_size=None):
+ alpha = self.f_p_alpha(V)
+ beta = self.f_p_beta(V)
+ self.p.value = alpha / (alpha + beta)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+
+ def f_p_alpha(self, V):
+ raise NotImplementedError
+
+ def f_p_beta(self, V):
+ raise NotImplementedError
+
+
+class IKDR_Ba2002(_IK_p4_markov):
+ r"""The delayed rectifier potassium channel current.
+
+ The potassium current model is adopted from (Bazhenov, et al., 2002) [1]_.
+ Its dynamics are given by:
+
+ .. math::
+
+ \begin{aligned}
+ I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\
+ \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
+ \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\
+ \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right)
+ \end{aligned}
+
+ where :math:`\phi` is a temperature-dependent factor, which is given by
+ :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The object size.
+ keep_size: bool
+ Whether to keep `size` as the shape of the variables. Otherwise, the
+ variables are flattened to the shape `(num,)`.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ T_base : float, ArrayType
+ The base of the temperature-dependent factor :math:`\phi`.
+ T : float, ArrayType, Initializer, Callable
+ The temperature (Celsius, :math:`^{\circ}C`).
+ V_sh : float, ArrayType, Initializer, Callable
+ The membrane potential shift.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ References
+ ----------
+ .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
+ and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.
+
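+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference).  With the
+ defaults ``T_base=3.`` and ``T=36.``, the temperature factor is
+ :math:`\phi = 3^{(36-36)/10} = 1`:
+
+ >>> import brainpy.math as bm
+ >>> ch = IKDR_Ba2002(size=1, T=36.)
+ >>> V = bm.asarray([-70., -50., -30.])
+ >>> alpha, beta = ch.f_p_alpha(V), ch.f_p_beta(V)   # opening / closing rates (1/ms)
+ >>> p_inf = alpha / (alpha + beta)                  # steady-state activation
+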
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = -50.,
+ T_base: Union[float, ArrayType] = 3.,
+ T: Union[float, ArrayType] = 36.,
+ phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ phi = T_base ** ((T - 36) / 10) if phi is None else phi
+ super(IKDR_Ba2002, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ g_max=g_max,
+ phi=phi,
+ E=E,
+ mode=mode)
+
+ # parameters
+ self.T = parameter(T, self.varshape, allow_none=False)
+ self.T_base = parameter(T_base, self.varshape, allow_none=False)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ tmp = V - self.V_sh - 15.
+ return 0.032 * tmp / (1. - bm.exp(-tmp / 5.))
+
+ def f_p_beta(self, V):
+ return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
+
+
+class IK_TM1991(_IK_p4_markov):
+ r"""The potassium channel described by (Traub and Miles, 1991) [1]_.
+
+ The dynamics of this channel is given by:
+
+ .. math::
+
+ \begin{aligned}
+ I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\
+ \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
+ \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\
+ \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40)
+ \end{aligned}
+
+ where :math:`V_{sh}` is the membrane shift (default -60 mV), and
+ :math:`\phi` is the temperature-dependent factor (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ References
+ ----------
+ .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus.
+ Vol. 777. Cambridge University Press, 1991.
+
+ See Also
+ --------
+ INa_TM1991
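+
+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference):
+
+ >>> import brainpy.math as bm
+ >>> ch = IK_TM1991(size=1, V_sh=-60.)
+ >>> V = bm.asarray([-70., -50., -30.])
+ >>> alpha = ch.f_p_alpha(V)   # 0.032 * c / (exp(c / 5) - 1), with c = 15 - V + V_sh
+ >>> beta = ch.f_p_beta(V)     # 0.5 * exp((10 - V + V_sh) / 40)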
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IK_TM1991, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=phi,
+ E=E,
+ g_max=g_max,
+ mode=mode)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ c = 15 - V + self.V_sh
+ return 0.032 * c / (bm.exp(c / 5) - 1.)
+
+ def f_p_beta(self, V):
+ return 0.5 * bm.exp((10 - V + self.V_sh) / 40)
+
+
+class IK_HH1952(_IK_p4_markov):
+ r"""The potassium channel described by Hodgkin–Huxley model [1]_.
+
+ The dynamics of this channel is given by:
+
+ .. math::
+
+ \begin{aligned}
+ I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\
+ \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\
+ \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\
+ \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right)
+ \end{aligned}
+
+ where :math:`V_{sh}` is the membrane shift (default -45 mV), and
+ :math:`\phi` is the temperature-dependent factor (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+
+ References
+ ----------
+ .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of
+ membrane current and its application to conduction and excitation in
+ nerve." The Journal of physiology 117.4 (1952): 500.
+
+ See Also
+ --------
+ INa_HH1952
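+
+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference); it assumes
+ the channel can be constructed and reset standalone:
+
+ >>> import brainpy.math as bm
+ >>> ch = IK_HH1952(size=1, E=-90., g_max=10.)
+ >>> V = bm.ones(1) * -65.
+ >>> ch.reset_state(V)       # p is set to its steady state at -65 mV
+ >>> I_K = ch.current(V)     # g_max * p**4 * (E - V)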
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IK_HH1952, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=phi,
+ E=E,
+ g_max=g_max,
+ mode=mode)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ temp = V - self.V_sh + 10
+ return 0.01 * temp / (1 - bm.exp(-temp / 10))
+
+ def f_p_beta(self, V):
+ return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
+
+
+class _IKA_p4q_ss(PotassiumChannel):
+ r"""The rapidly inactivating Potassium channel of :math:`p^4q`
+ current which described with steady-state format.
+
+ This model is developed according to the average behavior of
+ rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_.
+
+ .. math::
+
+ &IA = g_{\mathrm{max}} p^4 q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+ """
+ master_type = HHTypedNeuron
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
+ self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+ self.q = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(JointEq(self.dp, self.dq), method=method)
+
+ def dp(self, p, t, V):
+ return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
+
+ def dq(self, q, t, V):
+ return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
+
+ def update(self, V):
+ self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt'])
+
+ def current(self, V):
+ return self.g_max * self.p ** 4 * self.q * (self.E - V)
+
+ def reset_state(self, V, batch_size=None):
+ self.p.value = self.f_p_inf(V)
+ self.q.value = self.f_q_inf(V)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+ assert self.q.shape[0] == batch_size
+
+ def f_p_inf(self, V):
+ raise NotImplementedError
+
+ def f_p_tau(self, V):
+ raise NotImplementedError
+
+ def f_q_inf(self, V):
+ raise NotImplementedError
+
+ def f_q_tau(self, V):
+ raise NotImplementedError
+
+
+class IKA1_HM1992(_IKA_p4q_ss):
+ r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ This model is developed according to the average behavior of the
+ rapidly inactivating potassium channels in thalamic relay neurons [2]_ [1]_.
+
+ .. math::
+
+ &IA = g_{\mathrm{max}} p^4 q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\
+ &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\
+ \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array}
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ See Also
+ --------
+ IKA2_HM1992
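+
+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference):
+
+ >>> import brainpy.math as bm
+ >>> ch = IKA1_HM1992(size=1, E=-90., g_max=30.)
+ >>> V = bm.asarray([-90., -70., -50.])
+ >>> q_inf = ch.f_q_inf(V)   # steady-state inactivation
+ >>> tau_q = ch.f_q_tau(V)   # switches to the constant 19 ms above V_sh - 63 mV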
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 30.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IKA1_HM1992, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ E=E,
+ g_max=g_max,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
+ bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))
+
+ def f_q_tau(self, V):
+ return bm.where(V < -63 + self.V_sh,
+ 1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
+ bm.exp(-(V - self.V_sh + 238.) / 37.5)),
+ 19.)
+
+
+class IKA2_HM1992(_IKA_p4q_ss):
+ r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ This model is developed according to the average behavior of the
+ rapidly inactivating potassium channels in thalamic relay neurons [2]_ [1]_.
+
+ .. math::
+
+ &IA = g_{\mathrm{max}} p^4 q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\
+ &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\
+ \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array}
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ See Also
+ --------
+ IKA1_HM1992
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 20.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IKA2_HM1992, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ E=E,
+ g_max=g_max,
+ phi_q=phi_q,
+ phi_p=phi_p,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) +
+ bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.))
+
+ def f_q_tau(self, V):
+ return bm.where(V < -63 + self.V_sh,
+ 1. / (bm.exp((V - self.V_sh + 46.) / 5.) +
+ bm.exp(-(V - self.V_sh + 238.) / 37.5)),
+ 19.)
+
+
+class _IKK2_pq_ss(PotassiumChannel):
+ r"""The slowly inactivating Potassium channel of :math:`pq`
+ current which described with steady-state format.
+
+ The dynamics of the model is given as [2]_ [3]_.
+
+ .. math::
+
+ &IK2 = g_{\mathrm{max}} p q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ """
+ master_type = HHTypedNeuron
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
+ self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+ self.q = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(JointEq(self.dp, self.dq), method=method)
+
+ def dp(self, p, t, V):
+ return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
+
+ def dq(self, q, t, V):
+ return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V)
+
+ def update(self, V):
+ self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt'])
+
+ def current(self, V):
+ return self.g_max * self.p * self.q * (self.E - V)
+
+ def reset_state(self, V, batch_size=None):
+ self.p.value = self.f_p_inf(V)
+ self.q.value = self.f_q_inf(V)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+ assert self.q.shape[0] == batch_size
+
+ def f_p_inf(self, V):
+ raise NotImplementedError
+
+ def f_p_tau(self, V):
+ raise NotImplementedError
+
+ def f_q_inf(self, V):
+ raise NotImplementedError
+
+ def f_q_tau(self, V):
+ raise NotImplementedError
+
+
+class IKK2A_HM1992(_IKK2_pq_ss):
+ r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ The dynamics of the model are given as follows [2]_ [3]_:
+
+ .. math::
+
+ &IK2 = g_{\mathrm{max}} p q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+
+ \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\
+ & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
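+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference):
+
+ >>> import brainpy.math as bm
+ >>> ch = IKK2A_HM1992(size=1, E=-90.)
+ >>> V = bm.arange(-100., -30., 10.)
+ >>> overlap = ch.f_p_inf(V) * ch.f_q_inf(V)   # activation/inactivation overlap
+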
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IKK2A_HM1992, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ g_max=g_max,
+ E=E,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
+ bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))
+
+ def f_q_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+ bm.exp(-(V - self.V_sh + 130.) / 7.1))
+
+
+class IKK2B_HM1992(_IKK2_pq_ss):
+ r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_.
+
+ The dynamics of the model are given as follows [2]_ [3]_:
+
+ .. math::
+
+ &IK2 = g_{\mathrm{max}} p q (E-V) \\
+ &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\
+ &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\
+ &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+
+ \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\
+ &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\
+ &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\
+ &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) +
+ \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\
+ \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array}
+
+ where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.).
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ phi_q : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`q`.
+
+ References
+ ----------
+ .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the
+ currents involved in rhythmic oscillations in thalamic relay
+ neurons." Journal of neurophysiology 68.4 (1992): 1373-1383.
+ .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a
+ TEA-sensitive K current in acutely isolated rat thalamic relay
+ neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
+
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IKK2B_HM1992, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi_p=phi_p,
+ phi_q=phi_q,
+ g_max=g_max,
+ E=E,
+ mode=mode)
+
+ # parameters
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.))
+
+ def f_p_tau(self, V):
+ return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) +
+ bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9
+
+ def f_q_inf(self, V):
+ return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6))
+
+ def f_q_tau(self, V):
+ return bm.where(V < -70 + self.V_sh,
+ 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+ bm.exp(-(V - self.V_sh + 130.) / 7.1)),
+ 8.9)
+
+
+class IKNI_Ya1989(PotassiumChannel):
+ r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_.
+
+ This slow potassium current can effectively account for spike-frequency adaptation.
+
+ .. math::
+
+ \begin{aligned}
+ &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\
+ &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\
+ &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\
+ &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]}
+ \end{aligned}
+
+ where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and
+ :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise.
+
+ Parameters
+ ----------
+ size: int, sequence of int
+ The geometry size.
+ method: str
+ The numerical integration method.
+ name: str
+ The object name.
+ g_max : float, ArrayType, Initializer, Callable
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Initializer, Callable
+ The reversal potential (mV).
+ V_sh : float, ArrayType, Callable, Initializer
+ The membrane potential shift.
+ phi_p : optional, float, ArrayType, Callable, Initializer
+ The temperature factor for channel :math:`p`.
+ tau_max: float, ArrayType, Callable, Initializer
+ The :math:`\tau_{\max}` parameter.
+
+ References
+ ----------
+ .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133.
+
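+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference):
+
+ >>> import brainpy.math as bm
+ >>> ch = IKNI_Ya1989(size=1, g_max=0.004, tau_max=4e3)
+ >>> V = bm.asarray([-60., -40., -20.])
+ >>> p_inf = ch.f_p_inf(V)   # slow, non-inactivating activation gate
+ >>> tau_p = ch.f_p_tau(V)   # time constant in ms (tau_max = 4000 ms = 4 s)
+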
+ """
+ master_type = HHTypedNeuron
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -90.,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 0.004,
+ phi_p: Union[float, ArrayType, Initializer, Callable] = 1.,
+ phi_q: Union[float, ArrayType, Initializer, Callable] = 1.,
+ tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3,
+ V_sh: Union[float, ArrayType, Initializer, Callable] = 0.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super(IKNI_Ya1989, self).__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.tau_max = parameter(tau_max, self.varshape, allow_none=False)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+ self.phi_p = parameter(phi_p, self.varshape, allow_none=False)
+ self.phi_q = parameter(phi_q, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(self.dp, method=method)
+
+ def dp(self, p, t, V):
+ return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V)
+
+ def update(self, V):
+ self.p.value = self.integral(self.p.value, share['t'], V, share['dt'])
+
+ def current(self, V):
+ return self.g_max * self.p * (self.E - V)
+
+ def reset_state(self, V, batch_size=None):
+ self.p.value = self.f_p_inf(V)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+
+ def f_p_inf(self, V):
+ return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.))
+
+ def f_p_tau(self, V):
+ temp = V - self.V_sh + 35.
+ return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
+
+
+class IK_Leak(PotassiumChannel):
+ """The potassium leak channel current.
+
+ Parameters
+ ----------
+ g_max : float
+ The potassium leakage conductance which is modulated by both
+ acetylcholine and norepinephrine.
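+
+ Examples
+ --------
+ A minimal, illustrative sketch (not part of the original patch); the channel
+ has no state, so its current is purely ohmic in the driving force ``E - V``:
+
+ >>> ch = IK_Leak(size=1, g_max=0.005)
+ >>> I_KL = ch.current(V=-65., C=None, E=-90.)   # 0.005 * (-90. - (-65.))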
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005,
+ method: str = None,
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size=size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode)
+ self.g_max = self.init_param(g_max, self.varshape)
+
+ def reset_state(self, V, C, E, batch_size: int = None):
+ pass
+
+ def update(self, V, C, E):
+ pass
+
+ def current(self, V, C, E):
+ return self.g_max * (E - V)
diff --git a/brainpy/_src/dyn/channels/potassium_calcium.py b/brainpy/_src/dyn/channels/potassium_calcium.py
new file mode 100644
index 000000000..c74bb80f0
--- /dev/null
+++ b/brainpy/_src/dyn/channels/potassium_calcium.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+
+
+"""
+This module implements calcium-dependent potassium channels.
+"""
+
+from typing import Union, Callable, Optional
+
+import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.dyn.ions.calcium import Calcium
+from brainpy._src.dyn.ions.potassium import Potassium
+from brainpy._src.initialize import Initializer, parameter, variable
+from brainpy._src.integrators.ode.generic import odeint
+from brainpy._src.mixin import JointType
+from brainpy.types import Shape, ArrayType
+from .calcium import CalciumChannel
+from .potassium import PotassiumChannel
+
+__all__ = [
+ 'IAHP_De1994v2',
+]
+
+
+class KCaChannel(PotassiumChannel, CalciumChannel):
+ pass
+
+
+class IAHP_De1994v2(KCaChannel):
+ r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_.
+
+ Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in
+ vitro recordings (Avanzini et al. 1989) show the presence of a
+ marked after-hyper-polarization (AHP) after each burst of the RE
+ cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+
+ current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a
+ modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al.
+ 1989) that requires the binding of :math:`nCa^{2+}` to open the channel
+
+ .. math::
+
+ (\text{closed}) + n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\overset{\alpha}{\rightleftharpoons}} (\text{open})
+
+ where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and
+ :math:`\beta` are rate constants. The ionic current is then given by
+
+ .. math::
+
+ \begin{aligned}
+ I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\
+ {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\
+ p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\
+ \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)}
+ \end{aligned}
+
+ where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance,
+ and :math:`[Ca^{2+}]_i` is the intracellular calcium concentration.
+ The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and
+ :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those of RE cells
+ recorded in vivo and in vitro.
+
+ Parameters
+ ----------
+ g_max : float
+ The maximal conductance density (:math:`mS/cm^2`).
+
+ References
+ ----------
+
+ .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated
+ thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818.
+
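+ Examples
+ --------
+ A minimal, illustrative sketch (not from the original reference) of the calcium
+ dependence of the gate; the assumed :math:`[Ca^{2+}]_i` values are arbitrary:
+
+ >>> import brainpy.math as bm
+ >>> ch = IAHP_De1994v2(size=1)
+ >>> C_Ca = bm.asarray([1e-4, 1e-3, 1e-2])     # intracellular [Ca2+] (mM)
+ >>> C2 = ch.alpha * bm.power(C_Ca, ch.n)
+ >>> p_inf = C2 / (C2 + ch.beta)               # steady-state open probability
+ >>> tau_p = 1. / (C2 + ch.beta)               # relaxation time constant (ms)
+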
+ """
+
+ '''The type of the master object.'''
+ master_type = JointType[Calcium, Potassium]
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ n: Union[float, ArrayType, Initializer, Callable] = 2,
+ g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
+ alpha: Union[float, ArrayType, Initializer, Callable] = 48.,
+ beta: Union[float, ArrayType, Initializer, Callable] = 0.09,
+ phi: Union[float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(size=size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.n = parameter(n, self.varshape, allow_none=False)
+ self.alpha = parameter(alpha, self.varshape, allow_none=False)
+ self.beta = parameter(beta, self.varshape, allow_none=False)
+ self.phi = parameter(phi, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(self.dp, method=method)
+
+ def dp(self, p, t, C_Ca):
+ C2 = self.alpha * bm.power(C_Ca, self.n)
+ C3 = C2 + self.beta
+ return self.phi * (C2 / C3 - p) * C3
+
+ def update(self, V, Ca_info, K_info):
+ self.p.value = self.integral(self.p.value, share['t'], C_Ca=Ca_info['C'], dt=share['dt'])
+
+ def current(self, V, Ca_info, K_info):
+ return self.g_max * self.p * self.p * (K_info['E'] - V)
+
+ def reset_state(self, V, Ca_info, K_info, batch_size=None):
+ C2 = self.alpha * bm.power(Ca_info['C'], self.n)
+ C3 = C2 + self.beta
+ if batch_size is None:
+ self.p.value = bm.broadcast_to(C2 / C3, self.varshape)
+ else:
+ self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape)
+ assert self.p.shape[0] == batch_size
diff --git a/brainpy/_src/dyn/channels/KCa.py b/brainpy/_src/dyn/channels/potassium_calcium_compatible.py
similarity index 96%
rename from brainpy/_src/dyn/channels/KCa.py
rename to brainpy/_src/dyn/channels/potassium_calcium_compatible.py
index 28c53e64f..add47f169 100644
--- a/brainpy/_src/dyn/channels/KCa.py
+++ b/brainpy/_src/dyn/channels/potassium_calcium_compatible.py
@@ -8,20 +8,20 @@
from typing import Union, Callable
-from brainpy._src.context import share
import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.dyn.ions.calcium import Calcium
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import Shape, ArrayType
-from .base import CalciumChannel, PotassiumChannel
-from brainpy._src.dyn.ions.base import Calcium
+from .base import IonChannel
__all__ = [
'IAHP_De1994',
]
-class IAHP_De1994(PotassiumChannel, CalciumChannel):
+class IAHP_De1994(IonChannel):
r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_.
Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in
@@ -124,3 +124,4 @@ def reset_state(self, V, C_Ca, E_Ca, batch_size=None):
else:
self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape)
assert self.p.shape[0] == batch_size
+
diff --git a/brainpy/_src/dyn/channels/K.py b/brainpy/_src/dyn/channels/potassium_compatible.py
similarity index 93%
rename from brainpy/_src/dyn/channels/K.py
rename to brainpy/_src/dyn/channels/potassium_compatible.py
index 93f19a95e..d9bb41b61 100644
--- a/brainpy/_src/dyn/channels/K.py
+++ b/brainpy/_src/dyn/channels/potassium_compatible.py
@@ -5,27 +5,27 @@
"""
-from typing import Union, Callable, Optional
+from typing import Union, Callable, Optional, Sequence
import brainpy.math as bm
from brainpy._src.context import share
+from brainpy._src.dyn.channels.leaky import LeakyChannel
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
-from brainpy.types import Shape, ArrayType
-from .base import PotassiumChannel
+from brainpy.types import ArrayType
+from .potassium import PotassiumChannel
__all__ = [
'IKDR_Ba2002',
'IK_TM1991',
'IK_HH1952',
-
'IKA1_HM1992',
'IKA2_HM1992',
-
'IKK2A_HM1992',
'IKK2B_HM1992',
-
'IKNI_Ya1989',
+ 'IKL',
]
@@ -63,10 +63,11 @@ class _IK_p4_markov(PotassiumChannel):
The object name.
"""
+ master_type = HHTypedNeuron
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -75,10 +76,10 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(_IK_p4_markov, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
self.E = parameter(E, self.varshape, allow_none=False)
self.g_max = parameter(g_max, self.varshape, allow_none=False)
@@ -162,7 +163,7 @@ class IKDR_Ba2002(_IK_p4_markov):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -239,7 +240,7 @@ class IK_TM1991(_IK_p4_markov):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -310,7 +311,7 @@ class IK_HH1952(_IK_p4_markov):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -379,10 +380,11 @@ class _IKA_p4q_ss(PotassiumChannel):
TEA-sensitive K current in acutely isolated rat thalamic relay
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
"""
+ master_type = HHTypedNeuron
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -392,10 +394,10 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(_IKA_p4q_ss, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.E = parameter(E, self.varshape, allow_none=False)
@@ -496,7 +498,7 @@ class IKA1_HM1992(_IKA_p4q_ss):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 30.,
@@ -591,7 +593,7 @@ class IKA2_HM1992(_IKA_p4q_ss):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 20.,
@@ -673,10 +675,11 @@ class _IKK2_pq_ss(PotassiumChannel):
neurons." Journal of neurophysiology 66.4 (1991): 1316-1328.
"""
+ master_type = HHTypedNeuron
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -686,10 +689,10 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(_IKK2_pq_ss, self).__init__(size,
- keep_size=keep_size,
- name=name,
- mode=mode)
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
# parameters
self.E = parameter(E, self.varshape, allow_none=False)
@@ -786,7 +789,7 @@ class IKK2A_HM1992(_IKK2_pq_ss):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -822,7 +825,7 @@ def f_q_inf(self, V):
def f_q_tau(self, V):
return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
- bm.exp(-(V - self.V_sh + 130.) / 7.1))
+ bm.exp(-(V - self.V_sh + 130.) / 7.1))
class IKK2B_HM1992(_IKK2_pq_ss):
@@ -877,7 +880,7 @@ class IKK2B_HM1992(_IKK2_pq_ss):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 10.,
@@ -913,9 +916,9 @@ def f_q_inf(self, V):
def f_q_tau(self, V):
return bm.where(V < -70 + self.V_sh,
- 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
- bm.exp(-(V - self.V_sh + 130.) / 7.1)),
- 8.9)
+ 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) +
+ bm.exp(-(V - self.V_sh + 130.) / 7.1)),
+ 8.9)
class IKNI_Ya1989(PotassiumChannel):
@@ -959,10 +962,11 @@ class IKNI_Ya1989(PotassiumChannel):
.. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133.
"""
+ master_type = HHTypedNeuron
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[float, ArrayType, Initializer, Callable] = -90.,
g_max: Union[float, ArrayType, Initializer, Callable] = 0.004,
@@ -1013,3 +1017,44 @@ def f_p_inf(self, V):
def f_p_tau(self, V):
temp = V - self.V_sh + 35.
return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
+
+
+class IKL(LeakyChannel):
+ """The potassium leak channel current.
+
+ Parameters
+ ----------
+ g_max : float
+ The potassium leakage conductance which is modulated by both
+ acetylcholine and norepinephrine.
+ E : float
+ The reversal potential.
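+
+ Examples
+ --------
+ A minimal sketch, assuming ``IKL`` is re-exported at ``bp.dyn`` like the
+ other channels used in the tests of this patch:
+
+ >>> import brainpy as bp
+ >>> ikl = bp.dyn.IKL(10, g_max=0.005, E=-90.)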
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005,
+ E: Union[int, float, ArrayType, Initializer, Callable] = -90.,
+ method: str = None,
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+ self.method = method
+
+ def reset_state(self, V, batch_size=None):
+ pass
+
+ def update(self, V):
+ pass
+
+ def current(self, V):
+ return self.g_max * (self.E - V)
diff --git a/brainpy/_src/dyn/channels/sodium.py b/brainpy/_src/dyn/channels/sodium.py
new file mode 100644
index 000000000..66e93a45e
--- /dev/null
+++ b/brainpy/_src/dyn/channels/sodium.py
@@ -0,0 +1,381 @@
+# -*- coding: utf-8 -*-
+
+"""
+This module implements voltage-dependent sodium channels.
+
+"""
+
+from typing import Union, Callable
+
+import brainpy.math as bm
+from brainpy._src.context import share
+from brainpy._src.dyn.ions.sodium import Sodium
+from brainpy._src.initialize import Initializer, parameter, variable
+from brainpy._src.integrators import odeint, JointEq
+from brainpy.types import ArrayType, Shape
+from .base import IonChannel
+
+__all__ = [
+ 'SodiumChannel',
+ 'INa_Ba2002v2',
+ 'INa_TM1991v2',
+ 'INa_HH1952v2',
+]
+
+
+class SodiumChannel(IonChannel):
+ """Base class for sodium channel dynamics."""
+
+ master_type = Sodium
+
+ def update(self, V, C, E):
+ raise NotImplementedError
+
+ def current(self, V, C, E):
+ raise NotImplementedError
+
+ def reset(self, V, C, E, batch_size: int = None):
+ self.reset_state(V, C, E, batch_size)
+
+ def reset_state(self, V, C, E, batch_size: int = None):
+ raise NotImplementedError('Must be implemented by the subclass.')
+
+
+class _INa_p3q_markov_v2(SodiumChannel):
+ r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain.
+
+ The general model can be used to model the dynamics with:
+
+ .. math::
+
+ \begin{aligned}
+ I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\
+ \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\
+ \frac{dq}{dt} &= \phi ( \alpha_q (1-q) - \beta_q q)
+ \end{aligned}
+
+ where :math:`\phi` is a temperature-dependent factor.
+
+ Parameters
+ ----------
+ g_max : float, ArrayType, Callable, Initializer
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Callable, Initializer
+ The reversal potential (mV).
+ phi : float, ArrayType, Callable, Initializer
+ The temperature-dependent factor.
+ method: str
+ The numerical method
+ name: str
+ The name of the object.
+
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ g_max: Union[int, float, ArrayType, Initializer, Callable] = 90.,
+ phi: Union[int, float, ArrayType, Initializer, Callable] = 1.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size=size,
+ keep_size=keep_size,
+ name=name,
+ mode=mode)
+
+ # parameters
+ self.phi = parameter(phi, self.varshape, allow_none=False)
+ self.g_max = parameter(g_max, self.varshape, allow_none=False)
+
+ # variables
+ self.p = variable(bm.zeros, self.mode, self.varshape)
+ self.q = variable(bm.zeros, self.mode, self.varshape)
+
+ # function
+ self.integral = odeint(JointEq([self.dp, self.dq]), method=method)
+
+ def reset_state(self, V, C, E, batch_size=None):
+ alpha = self.f_p_alpha(V)
+ beta = self.f_p_beta(V)
+ self.p.value = alpha / (alpha + beta)
+ alpha = self.f_q_alpha(V)
+ beta = self.f_q_beta(V)
+ self.q.value = alpha / (alpha + beta)
+ if batch_size is not None:
+ assert self.p.shape[0] == batch_size
+ assert self.q.shape[0] == batch_size
+
+ def dp(self, p, t, V):
+ return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p)
+
+ def dq(self, q, t, V):
+ return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q)
+
+ def update(self, V, C, E):
+ p, q = self.integral(self.p, self.q, share['t'], V, share['dt'])
+ self.p.value, self.q.value = p, q
+
+ def current(self, V, C, E):
+ return self.g_max * self.p ** 3 * self.q * (E - V)
+
+ def f_p_alpha(self, V):
+ raise NotImplementedError
+
+ def f_p_beta(self, V):
+ raise NotImplementedError
+
+ def f_q_alpha(self, V):
+ raise NotImplementedError
+
+ def f_q_beta(self, V):
+ raise NotImplementedError
+
+
+class INa_Ba2002v2(_INa_p3q_markov_v2):
+ r"""The sodium current model.
+
+ The sodium current model is adopted from (Bazhenov, et al., 2002) [1]_.
+ Its dynamics is given by:
+
+ .. math::
+
+ \begin{aligned}
+ I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\
+ \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\
+ \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\
+ \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\
+ \frac{dq}{dt} &= \phi ( \alpha_q (1-q) - \beta_q q) \\
+ \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\
+ \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)}
+ \end{aligned}
+
+ where :math:`\phi` is a temperature-dependent factor, which is given by
+ :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius).
+
+ Parameters
+ ----------
+ g_max : float, ArrayType, Callable, Initializer
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Callable, Initializer
+ The reversal potential (mV).
+ T : float, ArrayType
+ The temperature (Celsius, :math:`^{\circ}C`).
+ V_sh : float, ArrayType, Callable, Initializer
+ The shift of the membrane potential to spike.
+
+ References
+ ----------
+
+ .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations
+ and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704.
+
+ See Also
+ --------
+ INa_TM1991
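+
+ Examples
+ --------
+ A minimal sketch of hosting this channel on a sodium ion model; the
+ ``SodiumFixed`` host and the ``ina`` key are illustrative:
+
+ >>> import brainpy as bp
+ >>> na = bp.dyn.SodiumFixed(10, ina=bp.dyn.INa_Ba2002v2(10))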
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ T: Union[int, float, ArrayType] = 36.,
+ g_max: Union[int, float, ArrayType, Initializer, Callable] = 90.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -50.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=3 ** ((T - 36) / 10),
+ g_max=g_max,
+ mode=mode)
+ self.T = parameter(T, self.varshape, allow_none=False)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ temp = V - self.V_sh - 13.
+ return 0.32 * temp / (1. - bm.exp(-temp / 4.))
+
+ def f_p_beta(self, V):
+ temp = V - self.V_sh - 40.
+ return -0.28 * temp / (1. - bm.exp(temp / 5.))
+
+ def f_q_alpha(self, V):
+ return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
+
+ def f_q_beta(self, V):
+ return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.))
+
+
+class INa_TM1991v2(_INa_p3q_markov_v2):
+ r"""The sodium current model described by (Traub and Miles, 1991) [1]_.
+
+ The dynamics of this sodium current model is given by:
+
+ .. math::
+
+ \begin{split}
+ \begin{aligned}
+ I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\
+ \frac {dm} {dt} &= \phi(\alpha_m (1-m) - \beta_m m) \\
+ &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\
+ &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\
+ \frac {dh} {dt} &= \phi(\alpha_h (1-h) - \beta_h h) \\
+ &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\
+ &\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\
+ \end{aligned}
+ \end{split}
+
+ where :math:`V_{sh}` is the membrane shift (default -63 mV), and
+ :math:`\phi` is the temperature-dependent factor (default 1.).
+
+ Parameters
+ ----------
+ size: int, tuple of int
+ The size of the simulation target.
+ keep_size: bool
+ Keep size or flatten the size?
+ method: str
+ The numerical method
+ name: str
+ The name of the object.
+ g_max : float, ArrayType, Callable, Initializer
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Callable, Initializer
+ The reversal potential (mV).
+ V_sh: float, ArrayType, Callable, Initializer
+ The membrane shift.
+
+ References
+ ----------
+ .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus.
+ Vol. 777. Cambridge University Press, 1991.
+
+ See Also
+ --------
+ INa_Ba2002
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ g_max: Union[int, float, ArrayType, Initializer, Callable] = 120.,
+ phi: Union[int, float, ArrayType, Initializer, Callable] = 1.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -63.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=phi,
+ g_max=g_max,
+ mode=mode)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ temp = 13 - V + self.V_sh
+ return 0.32 * temp / (bm.exp(temp / 4) - 1.)
+
+ def f_p_beta(self, V):
+ temp = V - self.V_sh - 40
+ return 0.28 * temp / (bm.exp(temp / 5) - 1)
+
+ def f_q_alpha(self, V):
+ return 0.128 * bm.exp((17 - V + self.V_sh) / 18)
+
+ def f_q_beta(self, V):
+ return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5))
+
+
+class INa_HH1952v2(_INa_p3q_markov_v2):
+ r"""The sodium current model described by Hodgkin–Huxley model [1]_.
+
+ The dynamics of this sodium current model is given by:
+
+ .. math::
+
+ \begin{split}
+ \begin{aligned}
+ I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\
+ \frac {dm} {dt} &= \phi (\alpha_m (1-m) - \beta_m m) \\
+ &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\
+ &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\
+ \frac {dh} {dt} &= \phi (\alpha_h (1-h) - \beta_h h) \\
+ &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\
+ &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\
+ \end{aligned}
+ \end{split}
+
+ where :math:`V_{sh}` is the membrane shift (default -45 mV), and
+ :math:`\phi` is the temperature-dependent factor (default 1.).
+
+ Parameters
+ ----------
+ size: int, tuple of int
+ The size of the simulation target.
+ keep_size: bool
+ Keep size or flatten the size?
+ method: str
+ The numerical method
+ name: str
+ The name of the object.
+ g_max : float, ArrayType, Callable, Initializer
+ The maximal conductance density (:math:`mS/cm^2`).
+ E : float, ArrayType, Callable, Initializer
+ The reversal potential (mV).
+ V_sh: float, ArrayType, Callable, Initializer
+ The membrane shift.
+
+ References
+ ----------
+ .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of
+ membrane current and its application to conduction and excitation in
+ nerve." The Journal of physiology 117.4 (1952): 500.
+
+ See Also
+ --------
+ IK_HH1952
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ g_max: Union[int, float, ArrayType, Initializer, Callable] = 120.,
+ phi: Union[int, float, ArrayType, Initializer, Callable] = 1.,
+ V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45.,
+ method: str = 'exp_auto',
+ name: str = None,
+ mode: bm.Mode = None,
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ name=name,
+ method=method,
+ phi=phi,
+ g_max=g_max,
+ mode=mode)
+ self.V_sh = parameter(V_sh, self.varshape, allow_none=False)
+
+ def f_p_alpha(self, V):
+ temp = V - self.V_sh - 5
+ return 0.1 * temp / (1 - bm.exp(-temp / 10))
+
+ def f_p_beta(self, V):
+ return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18)
+
+ def f_q_alpha(self, V):
+ return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.)
+
+ def f_q_beta(self, V):
+ return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10))
diff --git a/brainpy/_src/dyn/channels/Na.py b/brainpy/_src/dyn/channels/sodium_compatible.py
similarity index 96%
rename from brainpy/_src/dyn/channels/Na.py
rename to brainpy/_src/dyn/channels/sodium_compatible.py
index d29189ae8..9a05593b0 100644
--- a/brainpy/_src/dyn/channels/Na.py
+++ b/brainpy/_src/dyn/channels/sodium_compatible.py
@@ -5,14 +5,15 @@
"""
-from typing import Union, Callable
+from typing import Union, Callable, Sequence
import brainpy.math as bm
from brainpy._src.context import share
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
-from brainpy.types import ArrayType, Shape
-from .base import SodiumChannel
+from brainpy.types import ArrayType
+from .sodium import SodiumChannel
__all__ = [
'INa_Ba2002',
@@ -50,12 +51,13 @@ class _INa_p3q_markov(SodiumChannel):
The name of the object.
"""
+ master_type = HHTypedNeuron
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
- E: Union[int, float, ArrayType, Initializer, Callable] = 50.,
+ E: Union[int, float, ArrayType, Initializer, Callable] = None,
g_max: Union[int, float, ArrayType, Initializer, Callable] = 90.,
phi: Union[int, float, ArrayType, Initializer, Callable] = 1.,
method: str = 'exp_auto',
@@ -161,7 +163,7 @@ class INa_Ba2002(_INa_p3q_markov):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
T: Union[int, float, ArrayType] = 36.,
E: Union[int, float, ArrayType, Initializer, Callable] = 50.,
@@ -248,7 +250,7 @@ class INa_TM1991(_INa_p3q_markov):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[int, float, ArrayType, Initializer, Callable] = 50.,
g_max: Union[int, float, ArrayType, Initializer, Callable] = 120.,
@@ -335,7 +337,7 @@ class INa_HH1952(_INa_p3q_markov):
def __init__(
self,
- size: Shape,
+ size: Union[int, Sequence[int]],
keep_size: bool = False,
E: Union[int, float, ArrayType, Initializer, Callable] = 50.,
g_max: Union[int, float, ArrayType, Initializer, Callable] = 120.,
diff --git a/brainpy/_src/dyn/ions/__init__.py b/brainpy/_src/dyn/ions/__init__.py
index d9d4e9c37..ee840a720 100644
--- a/brainpy/_src/dyn/ions/__init__.py
+++ b/brainpy/_src/dyn/ions/__init__.py
@@ -1,3 +1,5 @@
from .base import *
-from .ca import *
+from .calcium import *
+from .potassium import *
+from .sodium import *
diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py
index bee8c08c2..804e551bc 100644
--- a/brainpy/_src/dyn/ions/base.py
+++ b/brainpy/_src/dyn/ions/base.py
@@ -1,61 +1,147 @@
# -*- coding: utf-8 -*-
-from typing import Union
+from typing import Union, Optional, Dict, Sequence, Callable
import brainpy.math as bm
-from brainpy._src.dyn.neurons.hh import CondNeuGroup
from brainpy._src.dyn.base import IonChaDyn
-from brainpy._src.mixin import Container, TreeNode
+from brainpy._src.dyn.neurons.hh import HHTypedNeuron
+from brainpy._src.mixin import Container, TreeNode, _JointGenericAlias
from brainpy.types import Shape
__all__ = [
+ 'MixIons',
+ 'mix_ions',
'Ion',
- 'Calcium',
]
-class Ion(IonChaDyn, TreeNode):
- """Base class for ions."""
+class MixIons(IonChaDyn, Container, TreeNode):
+ """Mixing Ions.
- '''The type of the master object.'''
- master_type = CondNeuGroup
+ Args:
+ ions: Instances of ions. This option defines the master types of all children objects.
+ channels: Instance of channels.
+ """
+ master_type = HHTypedNeuron
+
+ def __init__(self, *ions, **channels):
+ # TODO: check that the given ions are independent from each other
+ assert isinstance(ions, (tuple, list)), f'{self.__class__.__name__} requires a sequence of ions, but got {type(ions)}.'
+ assert len(ions) >= 2, f'{self.__class__.__name__} requires at least two ions.'
+ assert all([isinstance(ion, Ion) for ion in ions]), f'Must be a sequence of Ion. But got {ions}.'
+ super().__init__(size=ions[0].size, keep_size=ions[0].keep_size, sharding=ions[0].sharding)
+
+ self.ions: Sequence['Ion'] = tuple(ions)
+ self._ion_classes = tuple([type(ion) for ion in self.ions])
+ self.children = bm.node_dict()
+ for k, v in channels.items():
+ self.add_elem(k=v)
def update(self, V):
- raise NotImplementedError('Must be implemented by the subclass.')
+ nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
+ self.check_hierarchies(self._ion_classes, *nodes)
+ for node in nodes:
+ infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__])
+ node.update(V, *infos)
+
+ def current(self, V):
+ """Generate ion channel current.
- def reset(self, V, batch_size=None):
- self.reset_state(V, batch_size)
+ Args:
+ V: The membrane potential.
+
+ Returns:
+ Current.
+ """
+ nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
+ self.check_hierarchies(self._ion_classes, *nodes)
+
+ if len(nodes) == 0:
+ return 0.
+ else:
+ current = 0.
+ for node in nodes:
+ infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__])
+ current = current + node.current(V, *infos)
+ return current
def reset_state(self, V, batch_size=None):
- raise NotImplementedError('Must be implemented by the subclass.')
+ nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
+ self.check_hierarchies(self._ion_classes, *nodes)
+ for node in nodes:
+ infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__])
+ node.reset_state(V, *infos, batch_size)
+
+ def check_hierarchy(self, roots, leaf):
+ # 'master_type' should be a brainpy.mixin.JointType
+ self._check_master_type(leaf)
+ for cls in leaf.master_type.__args__:
+ if not any([issubclass(root, cls) for root in roots]):
+ raise TypeError(f'Type does not match. {leaf} requires a master with type '
+ f'of {leaf.master_type}, but the master type now is {roots}.')
+
+ def add_elem(self, **elements):
+ """Add new elements.
+
+ Args:
+ elements: children objects.
+ """
+ self.check_hierarchies(self._ion_classes, **elements)
+ self.children.update(self.format_elements(IonChaDyn, **elements))
+ for key, elem in elements.items():
+ for ion_root in elem.master_type.__args__:
+ ion = self._get_imp(ion_root)
+ ion.add_external_current(elem.name, self._get_ion_fun(ion, elem))
+
+ def _get_ion_fun(self, ion, node):
+ def fun(V, *args):
+ infos = tuple([(ion.pack_info(*args)
+ if isinstance(ion, root) else
+ self._get_imp(root).pack_info())
+ for root in node.master_type.__args__])
+ return node.current(V, *infos)
+ return fun
+
+ def _get_imp(self, cls):
+ for ion in self.ions:
+ if isinstance(ion, cls):
+ return ion
+ else:
+ raise ValueError(f'No instance of {cls} is found.')
- def current(self, V):
- raise NotImplementedError('Must be implemented by the subclass.')
+ def _check_master_type(self, leaf):
+ if not isinstance(leaf.master_type, _JointGenericAlias):
+ raise TypeError(f'{self.__class__.__name__} requires leaf nodes that have the master_type of '
+ f'"brainpy.mixin.JointType". However, we got {leaf.master_type}')
+
+
+def mix_ions(*ions) -> MixIons:
+ """Create mixed ions.
- def clear_input(self):
- pass
+ Args:
+ ions: Ion instances.
- def __repr__(self):
- return f'{self.name}(size={self.size})'
+ Returns:
+ Instance of MixIons.
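+
+ Example (mirroring the tests added in this patch, where the ions live inside
+ a ``CondNeuGroup``; the ``ahp`` key is illustrative):
+
+ >>> import brainpy as bp
+ >>> k = bp.dyn.PotassiumFixed(10)
+ >>> ca = bp.dyn.CalciumFirstOrder(10)
+ >>> kca = bp.dyn.mix_ions(k, ca)
+ >>> kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(10))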
+ """
+ for ion in ions:
+ assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}'
+ assert len(ions) > 0, 'Must provide at least one Ion instance.'
+ return MixIons(*ions)
-class Calcium(Ion, Container):
+class Ion(IonChaDyn, Container, TreeNode):
"""The brainpy_object calcium dynamics.
- Parameters
- ----------
- size: int, sequence of int
- The size of the simulation target.
- method: str
- The numerical integration method.
- name: str
- The name of the object.
- **channels
- The calcium dependent channels.
+ Args:
+ size: The size of the simulation target.
+ method: The numerical integration method.
+ name: The name of the object.
+ channels: The calcium dependent channels.
"""
'''The type of the master object.'''
- master_type = CondNeuGroup
+ master_type = HHTypedNeuron
"""Reversal potential."""
E: Union[float, bm.Variable, bm.Array]
@@ -68,29 +154,57 @@ def __init__(
size: Shape,
keep_size: bool = False,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
**channels
):
super().__init__(size, keep_size=keep_size, mode=mode, method=method, name=name)
-
self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels))
+ self.external: Dict[str, Callable] = dict() # not found by `.nodes()` or `.vars()`
def update(self, V):
for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
node.update(V, self.C, self.E)
- def current(self, V, C_Ca=None, E_Ca=None):
- C_Ca = self.C if (C_Ca is None) else C_Ca
- E_Ca = self.E if (E_Ca is None) else E_Ca
+ def current(self, V, C=None, E=None):
+ """Generate ion channel current.
+
+ Args:
+ V: The membrane potential.
+ C: The ion concentration.
+ E: The reversal potential.
+
+ Returns:
+ Current.
+ """
+ C = self.C if (C is None) else C
+ E = self.E if (E is None) else E
nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
+ self.check_hierarchies(type(self), *nodes)
- if len(nodes) == 0:
- return 0.
- else:
- self.check_hierarchies(self.__class__, *nodes)
- current = nodes[0].current(V, C_Ca, E_Ca)
- for node in nodes[1:]:
- current += node.current(V, C_Ca, E_Ca)
- return current
+ current = 0.
+ if len(nodes) > 0:
+ for node in nodes:
+ current = current + node.current(V, C, E)
+ for key, node in self.external.items():
+ current = current + node(V, C, E)
+ return current
+
+ def reset_state(self, V, batch_size=None):
+ nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values())
+ self.check_hierarchies(type(self), *nodes)
+ for node in nodes:
+ node.reset_state(V, self.C, self.E, batch_size)
+
+ def pack_info(self, C=None, E=None) -> Dict:
+ if C is None:
+ C = self.C
+ if E is None:
+ E = self.E
+ return dict(C=C, E=E)
+
+ def add_external_current(self, key: str, fun: Callable):
+ if key in self.external:
+ raise ValueError(f'The external current "{key}" has already been registered.')
+ self.external[key] = fun
diff --git a/brainpy/_src/dyn/ions/ca.py b/brainpy/_src/dyn/ions/calcium.py
similarity index 84%
rename from brainpy/_src/dyn/ions/ca.py
rename to brainpy/_src/dyn/ions/calcium.py
index 89bc2d2d1..4fa50daed 100644
--- a/brainpy/_src/dyn/ions/ca.py
+++ b/brainpy/_src/dyn/ions/calcium.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-from typing import Union, Callable
+from typing import Union, Callable, Optional
import brainpy.math as bm
from brainpy._src.context import share
@@ -8,15 +8,20 @@
from brainpy._src.initialize import OneInit, Initializer, parameter, variable
from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import Shape, ArrayType
-from .base import Calcium
+from .base import Ion
__all__ = [
+ 'Calcium',
'CalciumFixed',
'CalciumDetailed',
'CalciumFirstOrder',
]
+class Calcium(Ion):
+ pass
+
+
class CalciumFixed(Calcium):
"""Fixed Calcium dynamics.
@@ -31,16 +36,16 @@ def __init__(
E: Union[float, ArrayType, Initializer, Callable] = 120.,
C: Union[float, ArrayType, Initializer, Callable] = 2.4e-4,
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
**channels
):
- super(CalciumFixed, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- mode=mode,
- **channels)
+ super().__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode,
+ **channels)
self.E = parameter(E, self.varshape, allow_none=False)
self.C = parameter(C, self.varshape, allow_none=False)
@@ -82,16 +87,16 @@ def __init__(
T: Union[float, ArrayType, Initializer, Callable] = 36.,
C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
**channels
):
- super(CalciumDyna, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- mode=mode,
- **channels)
+ super().__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode,
+ **channels)
# parameters
self.C0 = parameter(C0, self.varshape, allow_none=False)
@@ -248,19 +253,19 @@ def __init__(
C0: Union[float, ArrayType, Initializer, Callable] = 2.,
C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
**channels
):
- super(CalciumDetailed, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- T=T,
- C0=C0,
- C_initializer=C_initializer,
- mode=mode,
- **channels)
+ super().__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ T=T,
+ C0=C0,
+ C_initializer=C_initializer,
+ mode=mode,
+ **channels)
# parameters
self.d = parameter(d, self.varshape, allow_none=False)
@@ -292,19 +297,19 @@ def __init__(
C0: Union[float, ArrayType, Initializer, Callable] = 2.,
C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4),
method: str = 'exp_auto',
- name: str = None,
- mode: bm.Mode = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
**channels
):
- super(CalciumFirstOrder, self).__init__(size,
- keep_size=keep_size,
- method=method,
- name=name,
- T=T,
- C0=C0,
- C_initializer=C_initializer,
- mode=mode,
- **channels)
+ super().__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ T=T,
+ C0=C0,
+ C_initializer=C_initializer,
+ mode=mode,
+ **channels)
# parameters
self.alpha = parameter(alpha, self.varshape, allow_none=False)
@@ -314,4 +319,3 @@ def derivative(self, C, t, V):
ICa = self.current(V, C, self.E)
drive = bm.maximum(- self.alpha * ICa, 0.)
return drive - self.beta * C
-
diff --git a/brainpy/_src/dyn/ions/potassium.py b/brainpy/_src/dyn/ions/potassium.py
new file mode 100644
index 000000000..b13c92458
--- /dev/null
+++ b/brainpy/_src/dyn/ions/potassium.py
@@ -0,0 +1,52 @@
+from typing import Union, Callable, Optional
+
+import brainpy.math as bm
+from brainpy._src.dyn.base import IonChaDyn
+from brainpy._src.initialize import Initializer
+from brainpy.types import Shape, ArrayType
+from .base import Ion
+
+__all__ = [
+ 'Potassium',
+ 'PotassiumFixed',
+]
+
+
+class Potassium(Ion):
+ pass
+
+
+class PotassiumFixed(Potassium):
+ """Fixed Sodium dynamics.
+
+ This calcium model has no dynamics. It holds fixed reversal
+ potential :math:`E` and concentration :math:`C`.
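+
+ A minimal sketch (mirroring the tests added in this patch):
+
+ >>> import brainpy as bp
+ >>> k = bp.dyn.PotassiumFixed(10)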
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = -950.,
+ C: Union[float, ArrayType, Initializer, Callable] = 0.0400811,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ **channels
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode,
+ **channels)
+ self.E = self.init_param(E, self.varshape)
+ self.C = self.init_param(C, self.varshape)
+
+ def reset_state(self, V, C=None, E=None, batch_size=None):
+ C = self.C if C is None else C
+ E = self.E if E is None else E
+ nodes = self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()
+ self.check_hierarchies(type(self), *tuple(nodes))
+ for node in nodes:
+ node.reset_state(V, C, E, batch_size)
diff --git a/brainpy/_src/dyn/ions/sodium.py b/brainpy/_src/dyn/ions/sodium.py
new file mode 100644
index 000000000..28a37d69f
--- /dev/null
+++ b/brainpy/_src/dyn/ions/sodium.py
@@ -0,0 +1,52 @@
+from typing import Union, Callable, Optional
+
+import brainpy.math as bm
+from brainpy._src.dyn.base import IonChaDyn
+from brainpy._src.initialize import Initializer, parameter
+from brainpy.types import Shape, ArrayType
+from .base import Ion
+
+__all__ = [
+ 'Sodium',
+ 'SodiumFixed',
+]
+
+
+class Sodium(Ion):
+ pass
+
+
+class SodiumFixed(Sodium):
+ """Fixed Sodium dynamics.
+
+ This sodium model has no dynamics. It holds a fixed reversal
+ potential :math:`E` and concentration :math:`C`.
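+
+ A minimal sketch (mirroring the tests added in this patch):
+
+ >>> import brainpy as bp
+ >>> na = bp.dyn.SodiumFixed(10)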
+ """
+
+ def __init__(
+ self,
+ size: Shape,
+ keep_size: bool = False,
+ E: Union[float, ArrayType, Initializer, Callable] = 50.,
+ C: Union[float, ArrayType, Initializer, Callable] = 0.0400811,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ **channels
+ ):
+ super().__init__(size,
+ keep_size=keep_size,
+ method=method,
+ name=name,
+ mode=mode,
+ **channels)
+ self.E = parameter(E, self.varshape, allow_none=False)
+ self.C = parameter(C, self.varshape, allow_none=False)
+
+ def reset_state(self, V, C=None, E=None, batch_size=None):
+ C = self.C if C is None else C
+ E = self.E if E is None else E
+ nodes = self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()
+ self.check_hierarchies(type(self), *tuple(nodes))
+ for node in nodes:
+ node.reset_state(V, C, E, batch_size)
diff --git a/brainpy/_src/dyn/ions/tests/test_MixIons.py b/brainpy/_src/dyn/ions/tests/test_MixIons.py
new file mode 100644
index 000000000..b2731968e
--- /dev/null
+++ b/brainpy/_src/dyn/ions/tests/test_MixIons.py
@@ -0,0 +1,98 @@
+import brainpy as bp
+import brainpy.math as bm
+
+import unittest
+
+
+class TestMixIons(unittest.TestCase):
+ def test_init(self):
+ class HH(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+
+ self.k = bp.dyn.PotassiumFixed(size)
+ self.ca = bp.dyn.CalciumFirstOrder(size)
+
+ self.kca = bp.dyn.mix_ions(self.k, self.ca)
+ self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size))
+
+ bm.random.seed()
+ HH(10)
+
+ def test_init2(self):
+ class HH(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+
+ self.k = bp.dyn.PotassiumFixed(size)
+ self.ca = bp.dyn.CalciumFirstOrder(size)
+
+ self.kca = bp.dyn.mix_ions(self.k, self.ca)
+ self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size))
+ self.kca.add_elem(na=bp.dyn.INa_Ba2002(size))
+
+ bm.random.seed()
+ with self.assertRaises(TypeError):
+ HH(10)
+
+ def test_init3(self):
+ class HH(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+
+ self.na = bp.dyn.SodiumFixed(size)
+ self.ca = bp.dyn.CalciumFirstOrder(size)
+
+ self.kca = bp.dyn.mix_ions(self.na, self.ca)
+ self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size))
+ self.kca.add_elem(na=bp.dyn.INa_Ba2002(size))
+
+ bm.random.seed()
+ with self.assertRaises(TypeError):
+ HH(10)
+
+ def test_init4(self):
+ class HH(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+
+ self.na = bp.dyn.SodiumFixed(size)
+ self.k = bp.dyn.PotassiumFixed(size)
+ self.ca = bp.dyn.CalciumFirstOrder(size)
+
+ self.kca = bp.dyn.mix_ions(self.na, self.k, self.ca)
+ self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size))
+
+ bm.random.seed()
+ HH(10)
+
+
+class TestMixIons2(unittest.TestCase):
+ def test_current1(self):
+ class HH(bp.dyn.CondNeuGroup):
+ def __init__(self, size):
+ super().__init__(size)
+
+ self.k = bp.dyn.PotassiumFixed(size)
+ self.na = bp.dyn.SodiumFixed(size)
+ self.ca = bp.dyn.CalciumFirstOrder(size)
+ self.kca = bp.dyn.MixIons(self.na, self.k, self.ca)
+
+ self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size))
+
+ bm.random.seed()
+ hh = HH(10)
+
+ hh.reset_state()
+
+ ICa = hh.ca.current(hh.V)
+ INa = hh.na.current(hh.V)
+ IK = hh.k.current(hh.V)
+ print(ICa, INa, IK)
+
+ self.assertTrue(bm.allclose(INa, 0.))
+ self.assertTrue(bm.allclose(ICa, IK))
+
+
+
+
diff --git a/brainpy/dyn/channels.py b/brainpy/dyn/channels.py
index 11809476a..eff433df8 100644
--- a/brainpy/dyn/channels.py
+++ b/brainpy/dyn/channels.py
@@ -2,8 +2,8 @@
IonChannel,
)
-from brainpy._src.dyn.channels.base import CalciumChannel
-from brainpy._src.dyn.channels.Ca import (
+from brainpy._src.dyn.channels.calcium import (
+ CalciumChannel,
ICaN_IS2008,
ICaT_HM1992,
ICaT_HP1992,
@@ -13,8 +13,19 @@
)
-from brainpy._src.dyn.channels.base import PotassiumChannel
-from brainpy._src.dyn.channels.K import (
+from brainpy._src.dyn.channels.potassium import (
+ PotassiumChannel,
+ IKDR_Ba2002v2,
+ IK_TM1991v2,
+ IK_HH1952v2,
+ IKA1_HM1992v2,
+ IKA2_HM1992v2,
+ IKK2A_HM1992v2,
+ IKK2B_HM1992v2,
+ IKNI_Ya1989v2,
+ IK_Leak,
+)
+from brainpy._src.dyn.channels.potassium_compatible import (
IKDR_Ba2002,
IK_TM1991,
IK_HH1952,
@@ -23,32 +34,42 @@
IKK2A_HM1992,
IKK2B_HM1992,
IKNI_Ya1989,
+ IKL,
)
-from brainpy._src.dyn.channels.base import IhChannel
-from brainpy._src.dyn.channels.IH import (
+from brainpy._src.dyn.channels.hyperpolarization_activated import (
+ IhChannel,
Ih_HM1992,
Ih_De1996,
)
-from brainpy._src.dyn.channels.KCa import (
+from brainpy._src.dyn.channels.potassium_calcium import (
+ IAHP_De1994v2
+)
+from brainpy._src.dyn.channels.potassium_calcium_compatible import (
IAHP_De1994
)
-from brainpy._src.dyn.channels.base import SodiumChannel
-from brainpy._src.dyn.channels.Na import (
+from brainpy._src.dyn.channels.sodium import (
+ SodiumChannel,
+)
+from brainpy._src.dyn.channels.sodium_compatible import (
INa_Ba2002,
INa_TM1991,
INa_HH1952,
)
+from brainpy._src.dyn.channels.sodium import (
+ INa_Ba2002v2,
+ INa_TM1991v2,
+ INa_HH1952v2,
+)
-from brainpy._src.dyn.channels.base import LeakyChannel
from brainpy._src.dyn.channels.leaky import (
+ LeakyChannel,
IL,
- IKL,
)
diff --git a/brainpy/dyn/ions.py b/brainpy/dyn/ions.py
index 8f040c971..d5b6bb254 100644
--- a/brainpy/dyn/ions.py
+++ b/brainpy/dyn/ions.py
@@ -1,12 +1,26 @@
+"""
+``brainpy.dyn.ions`` module defines the behavior of ion dynamics.
+"""
+
from brainpy._src.dyn.ions.base import (
Ion as Ion,
- Calcium as Calcium,
+ mix_ions as mix_ions,
+ MixIons as MixIons,
)
-
-from brainpy._src.dyn.ions.ca import (
+from brainpy._src.dyn.ions.calcium import (
+ Calcium as Calcium,
CalciumFixed as CalciumFixed,
CalciumDetailed as CalciumDetailed,
CalciumFirstOrder as CalciumFirstOrder,
)
+from brainpy._src.dyn.ions.sodium import (
+ Sodium as Sodium,
+ SodiumFixed as SodiumFixed,
+)
+from brainpy._src.dyn.ions.potassium import (
+ Potassium as Potassium,
+ PotassiumFixed as PotassiumFixed,
+)
+
diff --git a/brainpy/dyn/neurons.py b/brainpy/dyn/neurons.py
index ae4d06ee8..c8304c875 100644
--- a/brainpy/dyn/neurons.py
+++ b/brainpy/dyn/neurons.py
@@ -32,6 +32,7 @@
)
from brainpy._src.dyn.neurons.hh import (
+ HHTypedNeuron,
CondNeuGroupLTC,
CondNeuGroup,
HH,
From 1163d609074fd52774a0fd3e756a3b9bb8429425 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 11 Jul 2023 21:35:42 +0800
Subject: [PATCH 028/326] new updates
---
brainpy/__init__.py | 3 +-
brainpy/_add_deprecations.py | 15 +--
brainpy/_src/dnn/__init__.py | 1 +
brainpy/_src/dnn/activations.py | 56 ++++-----
brainpy/_src/dnn/base.py | 14 +++
brainpy/_src/dnn/conv.py | 6 +-
brainpy/_src/dnn/dropout.py | 6 +-
brainpy/_src/dnn/function.py | 8 +-
brainpy/_src/dnn/interoperation_flax.py | 4 +-
brainpy/_src/dnn/linear.py | 49 ++++----
brainpy/_src/dnn/normalization.py | 8 +-
brainpy/_src/dnn/nvar.py | 4 +-
brainpy/_src/dnn/pooling.py | 8 +-
brainpy/_src/dnn/reservoir.py | 4 +-
brainpy/_src/dnn/rnncells.py | 10 +-
brainpy/_src/dyn/neurons/hh.py | 11 +-
brainpy/_src/dyn/projections/__init__.py | 1 +
brainpy/_src/dyn/projections/conn.py | 106 +++++++++++++++++
brainpy/_src/dyn/rates/populations.py | 2 +-
brainpy/_src/dynold/synapses/base.py | 109 ++----------------
brainpy/_src/dynsys.py | 49 +++-----
brainpy/_src/mixin.py | 46 +++++---
brainpy/dnn/others.py | 3 +
brainpy/dyn/projections.py | 4 +
brainpy/synapses.py | 1 -
docs/index.rst | 3 +-
examples/dynamics_simulation/hh_model.py | 26 ++++-
.../dynamics_training/Song_2016_EI_RNN.py | 1 -
28 files changed, 308 insertions(+), 250 deletions(-)
create mode 100644 brainpy/_src/dnn/base.py
create mode 100644 brainpy/_src/dyn/projections/conn.py
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 68e72c21c..90edaca3d 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -61,7 +61,6 @@
Network as Network,
Dynamic as Dynamic, # category
Projection as Projection,
- AnnLayer as AnnLayer,
)
DynamicalSystemNS = DynamicalSystem
@@ -133,7 +132,7 @@
'TensorCollector': ('brainpy.TensorCollector', 'brainpy.ArrayCollector', ArrayCollector),
'SynSTP': ('brainpy.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
'SynOut': ('brainpy.SynOut', 'brainpy.synapses.SynOut', synapses.SynOut),
- 'SynConn': ('brainpy.SynConn', 'brainpy.synapses.SynConn', synapses.SynConn),
+ 'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
'TwoEndConn': ('brainpy.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
'CondNeuGroup': ('brainpy.CondNeuGroup', 'brainpy.syn.CondNeuGroup', dyn.CondNeuGroup),
}
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
index bd397ba24..05398c45f 100644
--- a/brainpy/_add_deprecations.py
+++ b/brainpy/_add_deprecations.py
@@ -82,14 +82,15 @@
'Container': ('brainpy.dyn.Container', 'brainpy.DynSysGroup', DynSysGroup),
'Sequential': ('brainpy.dyn.Sequential', 'brainpy.Sequential', Sequential),
'Network': ('brainpy.dyn.Network', 'brainpy.Network', Network),
- 'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.NeuDyn', NeuDyn),
'Channel': ('brainpy.dyn.Channel', 'brainpy.IonChaDyn', IonChaDyn),
'DSRunner': ('brainpy.dyn.DSRunner', 'brainpy.DSRunner', DSRunner),
+ # neurons
+ 'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.dyn.NeuDyn', NeuDyn),
+
# synapses
- 'SynConn': ('brainpy.dyn.SynConn', 'brainpy.synapses.SynConn', synapses.SynConn),
- 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
+ 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
'DeltaSynapse': ('brainpy.dyn.DeltaSynapse', 'brainpy.synapses.Delta', synapses.DeltaSynapse),
'ExpCUBA': ('brainpy.dyn.ExpCUBA', 'brainpy.synapses.Exponential', synapses.ExpCUBA),
'ExpCOBA': ('brainpy.dyn.ExpCOBA', 'brainpy.synapses.Exponential', synapses.ExpCOBA),
@@ -101,10 +102,10 @@
dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations)
-dnn.__deprecations = {
- 'Layer': ('brainpy.dnn.Layer', 'brainpy.AnnLayer', AnnLayer),
-}
-dnn.__getattr__ = deprecation_getattr2('brainpy.dnn', dnn.__deprecations)
+# dnn.__deprecations = {
+# 'Layer': ('brainpy.dnn.Layer', 'brainpy.AnnLayer', AnnLayer),
+# }
+# dnn.__getattr__ = deprecation_getattr2('brainpy.dnn', dnn.__deprecations)
layers.__deprecations = {
diff --git a/brainpy/_src/dnn/__init__.py b/brainpy/_src/dnn/__init__.py
index 6fa1eb184..f4b5f62c0 100644
--- a/brainpy/_src/dnn/__init__.py
+++ b/brainpy/_src/dnn/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
+from .base import *
from .activations import *
from .dropout import *
from .nvar import *
diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py
index d079e4421..532ae5444 100644
--- a/brainpy/_src/dnn/activations.py
+++ b/brainpy/_src/dnn/activations.py
@@ -1,7 +1,7 @@
from typing import Optional
from brainpy import math as bm
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
from brainpy.types import ArrayType
__all__ = [
@@ -21,7 +21,7 @@ def _inplace(inp, val, inplace):
return val
-class Threshold(AnnLayer):
+class Threshold(Layer):
r"""Thresholds each element of the input Tensor.
Threshold is defined as:
@@ -73,7 +73,7 @@ def extra_repr(self):
)
-class ReLU(AnnLayer):
+class ReLU(Layer):
r"""Applies the rectified linear unit function element-wise:
:math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
@@ -118,7 +118,7 @@ def extra_repr(self) -> str:
return inplace_str
-class RReLU(AnnLayer):
+class RReLU(Layer):
r"""Applies the randomized leaky rectified liner unit function, element-wise,
as described in the paper:
@@ -184,7 +184,7 @@ def extra_repr(self):
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
-class Hardtanh(AnnLayer):
+class Hardtanh(Layer):
r"""Applies the HardTanh function element-wise.
HardTanh is defined as:
@@ -275,7 +275,7 @@ def extra_repr(self) -> str:
return inplace_str
-class Sigmoid(AnnLayer):
+class Sigmoid(Layer):
r"""Applies the element-wise function:
.. math::
@@ -299,7 +299,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.sigmoid(input)
-class Hardsigmoid(AnnLayer):
+class Hardsigmoid(Layer):
r"""Applies the Hardsigmoid function element-wise.
Hardsigmoid is defined as:
@@ -339,7 +339,7 @@ def update(self, input: ArrayType) -> ArrayType:
return _inplace(input, x, self.inplace)
-class Tanh(AnnLayer):
+class Tanh(Layer):
r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
Tanh is defined as:
@@ -364,7 +364,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.tanh(input)
-class SiLU(AnnLayer):
+class SiLU(Layer):
r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
The SiLU function is also known as the swish function.
@@ -406,7 +406,7 @@ def extra_repr(self) -> str:
return inplace_str
-class Mish(AnnLayer):
+class Mish(Layer):
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
@@ -443,7 +443,7 @@ def extra_repr(self) -> str:
return inplace_str
-class Hardswish(AnnLayer):
+class Hardswish(Layer):
r"""Applies the Hardswish function, element-wise, as described in the paper:
`Searching for MobileNetV3 `_.
@@ -483,7 +483,7 @@ def update(self, input: ArrayType) -> ArrayType:
return _inplace(input, bm.hard_swish(input), self.inplace)
-class ELU(AnnLayer):
+class ELU(Layer):
r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
Units (ELUs) `__.
@@ -529,7 +529,7 @@ def extra_repr(self) -> str:
return 'alpha={}{}'.format(self.alpha, inplace_str)
-class CELU(AnnLayer):
+class CELU(Layer):
r"""Applies the element-wise function:
.. math::
@@ -573,7 +573,7 @@ def extra_repr(self) -> str:
return 'alpha={}{}'.format(self.alpha, inplace_str)
-class SELU(AnnLayer):
+class SELU(Layer):
r"""Applied element-wise, as:
.. math::
@@ -616,7 +616,7 @@ def extra_repr(self) -> str:
return inplace_str
-class GLU(AnnLayer):
+class GLU(Layer):
r"""Applies the gated linear unit function
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
of the input matrices and :math:`b` is the second half.
@@ -651,7 +651,7 @@ def extra_repr(self) -> str:
return 'dim={}'.format(self.dim)
-class GELU(AnnLayer):
+class GELU(Layer):
r"""Applies the Gaussian Error Linear Units function:
.. math:: \text{GELU}(x) = x * \Phi(x)
@@ -692,7 +692,7 @@ def extra_repr(self) -> str:
return 'approximate={}'.format(repr(self.approximate))
-class Hardshrink(AnnLayer):
+class Hardshrink(Layer):
r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
Hardshrink is defined as:
@@ -734,7 +734,7 @@ def extra_repr(self) -> str:
return '{}'.format(self.lambd)
-class LeakyReLU(AnnLayer):
+class LeakyReLU(Layer):
r"""Applies the element-wise function:
.. math::
@@ -785,7 +785,7 @@ def extra_repr(self) -> str:
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
-class LogSigmoid(AnnLayer):
+class LogSigmoid(Layer):
r"""Applies the element-wise function:
.. math::
@@ -808,7 +808,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.log_sigmoid(input)
-class Softplus(AnnLayer):
+class Softplus(Layer):
r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
\log(1 + \exp(\beta * x))` element-wise.
@@ -850,7 +850,7 @@ def extra_repr(self) -> str:
return 'beta={}, threshold={}'.format(self.beta, self.threshold)
-class Softshrink(AnnLayer):
+class Softshrink(Layer):
r"""Applies the soft shrinkage function elementwise:
.. math::
@@ -890,7 +890,7 @@ def extra_repr(self) -> str:
return str(self.lambd)
-class PReLU(AnnLayer):
+class PReLU(Layer):
r"""Applies the element-wise function:
.. math::
@@ -954,7 +954,7 @@ def extra_repr(self) -> str:
return 'num_parameters={}'.format(self.num_parameters)
-class Softsign(AnnLayer):
+class Softsign(Layer):
r"""Applies the element-wise function:
.. math::
@@ -977,7 +977,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.soft_sign(input)
-class Tanhshrink(AnnLayer):
+class Tanhshrink(Layer):
r"""Applies the element-wise function:
.. math::
@@ -1000,7 +1000,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.tanh_shrink(input)
-class Softmin(AnnLayer):
+class Softmin(Layer):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range `[0, 1]` and sum to 1.
@@ -1045,7 +1045,7 @@ def extra_repr(self):
return 'dim={dim}'.format(dim=self.dim)
-class Softmax(AnnLayer):
+class Softmax(Layer):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range [0,1] and sum to 1.
@@ -1099,7 +1099,7 @@ def extra_repr(self) -> str:
return 'dim={dim}'.format(dim=self.dim)
-class Softmax2d(AnnLayer):
+class Softmax2d(Layer):
r"""Applies SoftMax over features to each spatial location.
When given an image of ``Channels x Height x Width``, it will
@@ -1128,7 +1128,7 @@ def update(self, input: ArrayType) -> ArrayType:
return bm.softmax(input, -3)
-class LogSoftmax(AnnLayer):
+class LogSoftmax(Layer):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor. The LogSoftmax formulation can be simplified as:
diff --git a/brainpy/_src/dnn/base.py b/brainpy/_src/dnn/base.py
new file mode 100644
index 000000000..40665956c
--- /dev/null
+++ b/brainpy/_src/dnn/base.py
@@ -0,0 +1,14 @@
+from brainpy._src.dynsys import DynamicalSystem
+
+
+__all__ = [
+ 'Layer'
+]
+
+
+class Layer(DynamicalSystem):
+ """Base class for a layer of artificial neural network."""
+
+ def reset_state(self, *args, **kwargs):
+ pass
+
diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py
index daf85ad74..6f4964647 100644
--- a/brainpy/_src/dnn/conv.py
+++ b/brainpy/_src/dnn/conv.py
@@ -7,7 +7,7 @@
from brainpy import math as bm, tools
from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.types import ArrayType
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'Conv1d', 'Conv2d', 'Conv3d',
@@ -36,7 +36,7 @@ def to_dimension_numbers(num_spatial_dims: int,
out_spec=image_dn)
-class _GeneralConv(AnnLayer):
+class _GeneralConv(Layer):
"""Apply a convolution to the inputs.
Parameters
@@ -462,7 +462,7 @@ def _check_input_dim(self, x):
Conv3D = Conv3d
-class _GeneralConvTranspose(AnnLayer):
+class _GeneralConvTranspose(Layer):
supported_modes = (bm.TrainingMode, bm.BatchingMode)
def __init__(
diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py
index c5583b67f..0ec7ad494 100644
--- a/brainpy/_src/dnn/dropout.py
+++ b/brainpy/_src/dnn/dropout.py
@@ -40,8 +40,10 @@ def __init__(
super(Dropout, self).__init__(mode=mode, name=name)
self.prob = check.is_float(prob, min_bound=0., max_bound=1.)
- def update(self, x):
- if share.load('fit'):
+ def update(self, x, fit: Optional[bool] = None):
+ if fit is None:
+ fit = share['fit']
+ if fit:
keep_mask = bm.random.bernoulli(self.prob, x.shape)
return bm.where(keep_mask, x / self.prob, 0.)
else:
diff --git a/brainpy/_src/dnn/function.py b/brainpy/_src/dnn/function.py
index 0223a387a..78a7253fc 100644
--- a/brainpy/_src/dnn/function.py
+++ b/brainpy/_src/dnn/function.py
@@ -5,7 +5,7 @@
import brainpy.math as bm
from brainpy import check
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'Activation',
@@ -14,7 +14,7 @@
]
-class Activation(AnnLayer):
+class Activation(Layer):
r"""Applies an activation function to the inputs
Parameters:
@@ -43,7 +43,7 @@ def update(self, *args, **kwargs):
return self.activate_fun(*args, **kwargs, **self.kwargs)
-class Flatten(AnnLayer):
+class Flatten(Layer):
r"""Flattens a contiguous range of dims into 2D or 1D.
Parameters:
@@ -69,7 +69,7 @@ def update(self, x):
return x.flatten()
-class FunAsLayer(AnnLayer):
+class FunAsLayer(Layer):
def __init__(
self,
fun: Callable,
diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py
index 5765df8fa..09f03ac13 100644
--- a/brainpy/_src/dnn/interoperation_flax.py
+++ b/brainpy/_src/dnn/interoperation_flax.py
@@ -7,7 +7,7 @@
from brainpy import math as bm
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
try:
import flax # noqa
@@ -35,7 +35,7 @@ def _is_bp(a):
return isinstance(a, bm.Array)
-class FromFlax(AnnLayer):
+class FromFlax(Layer):
"""
Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`.
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index a34f148c2..ef7cc377f 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -14,7 +14,7 @@
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'Dense', 'Linear',
@@ -28,7 +28,7 @@
]
-class Dense(AnnLayer):
+class Dense(Layer):
r"""A linear transformation applied over the last dimension of the input.
Mathematically, this node can be defined as:
@@ -207,7 +207,7 @@ def offline_fit(self,
Linear = Dense
-class Identity(AnnLayer):
+class Identity(Layer):
r"""A placeholder identity operator that is argument-insensitive.
"""
@@ -218,7 +218,7 @@ def update(self, x):
return x
-class AllToAll(AnnLayer):
+class AllToAll(Layer):
"""Synaptic matrix multiplication with All2All connections.
Args:
@@ -281,7 +281,7 @@ def update(self, pre_val):
return post_val
-class OneToOne(AnnLayer):
+class OneToOne(Layer):
"""Synaptic matrix multiplication with One2One connection.
Args:
@@ -315,7 +315,7 @@ def update(self, pre_val):
return pre_val * self.weight
-class MaskedLinear(AnnLayer):
+class MaskedLinear(Layer):
r"""Synaptic matrix multiplication with masked dense computation.
It performs the computation of:
@@ -332,8 +332,9 @@ class MaskedLinear(AnnLayer):
>>> weight=0.1)
Args:
- mask: TwoEndConnector. The connection.
+ conn: TwoEndConnector. The connection.
weight: Synaptic weights. Can be a scalar, array, or callable function.
+ mask_fun: Callable. The masking function applied to the masked weight matrix before the matrix multiplication.
sharding: The sharding strategy.
mode: The synaptic computing mode.
name: The synapse model name.
@@ -341,20 +342,22 @@ class MaskedLinear(AnnLayer):
def __init__(
self,
- mask: connect.TwoEndConnector,
+ conn: connect.TwoEndConnector,
weight: Union[float, ArrayType, Callable],
+ mask_fun: Callable = Identity(),
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
super().__init__(name=name, mode=mode)
- assert isinstance(mask, connect.TwoEndConnector)
- self.conn = mask
+ assert isinstance(conn, connect.TwoEndConnector)
+ self.conn = conn
self.sharding = sharding
+ self.mask_fun = mask_fun
# weight
- weight = init.parameter(weight, (mask.pre_num, mask.post_num), sharding=sharding)
+ weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding)
if isinstance(self.mode, bm.TrainingMode):
weight = bm.TrainVar(weight)
self.weight = weight
@@ -363,10 +366,10 @@ def __init__(
self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding)
def update(self, x):
- return x @ (self.weight * self.mask)
+ return x @ self.mask_fun(self.weight * self.mask)
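Together with the renamed conn argument, the new mask_fun hook is applied to the masked weights before the matrix multiplication. A hedged construction sketch (the connector, the sizes, and the bp.dnn.MaskedLinear alias are assumptions):

import brainpy as bp
import brainpy.math as bm

conn = bp.connect.FixedProb(prob=0.1, seed=42)(pre_size=100, post_size=50)
layer = bp.dnn.MaskedLinear(conn, weight=0.1)   # mask_fun defaults to Identity()
x = bm.random.randn(8, 100)
y = layer(x)   # shape (8, 50); weights outside the connectivity mask contribute nothing
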
-class CSRLinear(AnnLayer):
+class CSRLinear(Layer):
r"""Synaptic matrix multiplication with CSR sparse computation.
It performs the computation of:
@@ -435,7 +438,7 @@ def _batch_csrmv(self, x):
method=self.method)
-class CSCLinear(AnnLayer):
+class CSCLinear(Layer):
r"""Synaptic matrix multiplication with CSC sparse computation.
It performs the computation of:
@@ -470,7 +473,7 @@ def __init__(
self.sharding = sharding
-class EventCSRLinear(AnnLayer):
+class EventCSRLinear(Layer):
r"""Synaptic matrix multiplication with event CSR sparse computation.
It performs the computation of:
@@ -535,7 +538,7 @@ def _batch_csrmv(self, x):
transpose=self.transpose)
-class BcsrMM(AnnLayer):
+class BcsrMM(Layer):
r"""Synaptic matrix multiplication with BCSR sparse computation.
It performs the computation of:
@@ -570,7 +573,7 @@ def __init__(
self.sharding = sharding
-class BcscMM(AnnLayer):
+class BcscMM(Layer):
r"""Synaptic matrix multiplication with BCSC sparse computation.
It performs the computation of:
@@ -605,7 +608,7 @@ def __init__(
self.sharding = sharding
-class JitFPHomoLinear(AnnLayer):
+class JitFPHomoLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -684,7 +687,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class JitFPUniformLinear(AnnLayer):
+class JitFPUniformLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -764,7 +767,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class JitFPNormalLinear(AnnLayer):
+class JitFPNormalLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -844,7 +847,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class EventJitFPHomoLinear(AnnLayer):
+class EventJitFPHomoLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -923,7 +926,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class EventJitFPUniformLinear(AnnLayer):
+class EventJitFPUniformLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
@@ -1003,7 +1006,7 @@ def _batch_mv(self, x):
outdim_parallel=not self.atomic)
-class EventJitFPNormalLinear(AnnLayer):
+class EventJitFPNormalLinear(Layer):
r"""Synaptic matrix multiplication with the just-in-time connectivity.
It performs the computation of:
diff --git a/brainpy/_src/dnn/normalization.py b/brainpy/_src/dnn/normalization.py
index dad6dd841..8df9be62b 100644
--- a/brainpy/_src/dnn/normalization.py
+++ b/brainpy/_src/dnn/normalization.py
@@ -8,7 +8,7 @@
from brainpy import math as bm, check
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.types import ArrayType
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'BatchNorm1d',
@@ -32,7 +32,7 @@ def _square(x):
return lax.square(x)
-class BatchNorm(AnnLayer):
+class BatchNorm(Layer):
r"""Batch Normalization layer [1]_.
This layer aims to reduce the internal covariant shift of data. It
@@ -407,7 +407,7 @@ def _check_input_dim(self, x):
assert x.shape[-1] == self.num_features
-class LayerNorm(AnnLayer):
+class LayerNorm(Layer):
r"""Layer normalization (https://arxiv.org/abs/1607.06450).
.. math::
@@ -504,7 +504,7 @@ def update(self, x):
return out
-class GroupNorm(AnnLayer):
+class GroupNorm(Layer):
r"""Group normalization layer.
.. math::
diff --git a/brainpy/_src/dnn/nvar.py b/brainpy/_src/dnn/nvar.py
index 87029a45b..c980a524c 100644
--- a/brainpy/_src/dnn/nvar.py
+++ b/brainpy/_src/dnn/nvar.py
@@ -8,7 +8,7 @@
import brainpy.math as bm
from brainpy import check
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'NVAR'
@@ -34,7 +34,7 @@ def _comb(N, k):
return 0
-class NVAR(AnnLayer):
+class NVAR(Layer):
"""Nonlinear vector auto-regression (NVAR) node.
This class has the following features:
diff --git a/brainpy/_src/dnn/pooling.py b/brainpy/_src/dnn/pooling.py
index 148e8537e..ac49ab45b 100644
--- a/brainpy/_src/dnn/pooling.py
+++ b/brainpy/_src/dnn/pooling.py
@@ -7,7 +7,7 @@
import numpy as np
from brainpy import math as bm, check
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'MaxPool',
@@ -28,7 +28,7 @@
]
-class Pool(AnnLayer):
+class Pool(Layer):
"""Pooling functions are implemented using the ReduceWindow XLA op.
Parameters
@@ -285,7 +285,7 @@ def update(self, x):
return pooled / window_counts
-class _MaxPoolNd(AnnLayer):
+class _MaxPoolNd(Layer):
def __init__(
self,
init_value,
@@ -717,7 +717,7 @@ def _generate_vmap(fun: Callable, map_axes: List[int]):
return fun
-class AdaptivePool(AnnLayer):
+class AdaptivePool(Layer):
"""General N dimensional adaptive down-sampling to a target shape.
Parameters
diff --git a/brainpy/_src/dnn/reservoir.py b/brainpy/_src/dnn/reservoir.py
index e21605ac2..e092991e2 100644
--- a/brainpy/_src/dnn/reservoir.py
+++ b/brainpy/_src/dnn/reservoir.py
@@ -9,14 +9,14 @@
from brainpy import check
from brainpy.tools import to_size
from brainpy.types import ArrayType
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'Reservoir',
]
-class Reservoir(AnnLayer):
+class Reservoir(Layer):
r"""Reservoir node, a pool of leaky-integrator neurons
with random recurrent connections [1]_.
diff --git a/brainpy/_src/dnn/rnncells.py b/brainpy/_src/dnn/rnncells.py
index 0038e2d29..f74f4acc5 100644
--- a/brainpy/_src/dnn/rnncells.py
+++ b/brainpy/_src/dnn/rnncells.py
@@ -7,7 +7,7 @@
import brainpy.math as bm
from brainpy.math import activations
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
from brainpy.check import (is_integer,
is_initializer)
from brainpy.initialize import (XavierNormal,
@@ -27,7 +27,7 @@
]
-class RNNCell(AnnLayer):
+class RNNCell(Layer):
r"""Basic fully-connected RNN core.
Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the
@@ -125,7 +125,7 @@ def update(self, x):
return self.state.value
-class GRUCell(AnnLayer):
+class GRUCell(Layer):
r"""Gated Recurrent Unit.
The implementation is based on (Chung, et al., 2014) [1]_ with biases.
@@ -247,7 +247,7 @@ def update(self, x):
return self.state.value
-class LSTMCell(AnnLayer):
+class LSTMCell(Layer):
r"""Long short-term memory (LSTM) RNN core.
The implementation is based on (zaremba, et al., 2014) [1]_. Given
@@ -442,7 +442,7 @@ def __init__(self, *args, **kwargs):
super(LSTM, self).__init__(*args, **kwargs)
-class _ConvNDLSTMCell(AnnLayer):
+class _ConvNDLSTMCell(Layer):
r"""``num_spatial_dims``-D convolutional LSTM.
The implementation is based on :cite:`xingjian2015convolutional`.
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index 482a3ac91..4f6e68d34 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -16,6 +16,7 @@
from brainpy.types import Shape
__all__ = [
+ 'HHTypedNeuron',
'CondNeuGroupLTC',
'CondNeuGroup',
'HHLTC',
@@ -27,11 +28,11 @@
]
-class HHTypedNeuron(NeuDyn, Container, TreeNode):
- master_type = DynamicalSystem
+class HHTypedNeuron(NeuDyn):
+ pass
-class CondNeuGroupLTC(HHTypedNeuron):
+class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode):
r"""Base class to model conductance-based neuron group.
The standard formulation for a conductance-based model is given as
@@ -149,7 +150,7 @@ def update(self, x=None):
# update channels
for node in channels.values():
- node.update(self.V.value)
+ node(self.V.value)
# update variables
if self.spike.dtype == bool:
@@ -157,7 +158,7 @@ def update(self, x=None):
else:
self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th).astype(self.spike.dtype)
self.V.value = V
- return self.spike
+ return self.spike.value
def clear_input(self):
"""Useful for monitoring inputs. """
diff --git a/brainpy/_src/dyn/projections/__init__.py b/brainpy/_src/dyn/projections/__init__.py
index e58f35554..3efded3a6 100644
--- a/brainpy/_src/dyn/projections/__init__.py
+++ b/brainpy/_src/dyn/projections/__init__.py
@@ -1,3 +1,4 @@
from .aligns import *
+from .conn import *
from .others import *
diff --git a/brainpy/_src/dyn/projections/conn.py b/brainpy/_src/dyn/projections/conn.py
new file mode 100644
index 000000000..297b3bc98
--- /dev/null
+++ b/brainpy/_src/dyn/projections/conn.py
@@ -0,0 +1,106 @@
+from typing import Union, Dict, Optional
+
+import jax
+import numpy as np
+
+from brainpy import math as bm
+from brainpy._src.connect import TwoEndConnector, MatConn, IJConn
+from brainpy._src.dynsys import Projection, DynamicalSystem
+from brainpy.types import ArrayType
+
+__all__ = [
+ 'SynConn',
+]
+
+
+class SynConn(Projection):
+ """Base class to model two-end synaptic connections.
+
+ Parameters
+ ----------
+ pre : NeuGroup
+ Pre-synaptic neuron group.
+ post : NeuGroup
+ Post-synaptic neuron group.
+ conn : optional, ndarray, ArrayType, dict, TwoEndConnector
+ The connection method between pre- and post-synaptic groups.
+ name : str, optional
+ The name of the dynamic system.
+ """
+
+ def __init__(
+ self,
+ pre: DynamicalSystem,
+ post: DynamicalSystem,
+ conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # pre or post neuron group
+ # ------------------------
+ if not isinstance(pre, DynamicalSystem):
+ raise TypeError('"pre" must be an instance of DynamicalSystem.')
+ if not isinstance(post, DynamicalSystem):
+ raise TypeError('"post" must be an instance of DynamicalSystem.')
+ self.pre = pre
+ self.post = post
+
+ # connectivity
+ # ------------
+ if isinstance(conn, TwoEndConnector):
+ self.conn = conn(pre.size, post.size)
+ elif isinstance(conn, (bm.Array, np.ndarray, jax.Array)):
+ if (pre.num, post.num) != conn.shape:
+ raise ValueError(f'"conn" is provided as a matrix, and it is expected '
+ f'to be an array with shape of (pre.num, post.num) = '
+ f'{(pre.num, post.num)}, however we got {conn.shape}')
+ self.conn = MatConn(conn_mat=conn)
+ elif isinstance(conn, dict):
+ if not ('i' in conn and 'j' in conn):
+ raise ValueError(f'"conn" is provided as a dict, and it is expected to '
+ f'be a dictionary with "i" and "j" specification, '
+ f'however we got {conn}')
+ self.conn = IJConn(i=conn['i'], j=conn['j'])
+ elif isinstance(conn, str):
+ self.conn = conn
+ elif conn is None:
+ self.conn = None
+ else:
+ raise ValueError(f'Unknown "conn" type: {conn}')
+
+ def __repr__(self):
+ names = self.__class__.__name__
+ return (f'{names}(name={self.name}, mode={self.mode}, \n'
+ f'{" " * len(names)} pre={self.pre}, \n'
+ f'{" " * len(names)} post={self.post})')
+
+ def check_pre_attrs(self, *attrs):
+ """Check whether pre group satisfies the requirement."""
+ if not hasattr(self, 'pre'):
+ raise ValueError('Please call __init__ function first.')
+ for attr in attrs:
+ if not isinstance(attr, str):
+ raise TypeError(f'Must be string. But got {attr}.')
+ if not hasattr(self.pre, attr):
+ raise ValueError(f'{self} requires the "pre" neuron group to have the attribute "{attr}".')
+
+ def check_post_attrs(self, *attrs):
+ """Check whether post group satisfies the requirement."""
+ if not hasattr(self, 'post'):
+ raise ValueError('Please call __init__ function first.')
+ for attr in attrs:
+ if not isinstance(attr, str):
+ raise TypeError(f'Must be string. But got {attr}.')
+ if not hasattr(self.post, attr):
+ raise ValueError(f'{self} requires the "post" neuron group to have the attribute "{attr}".')
+
+ def update(self, *args, **kwargs):
+ """The function to specify the updating rule.
+
+ Any dynamical system may depend on the shared variables, such as the
+ time variable ``t``, the integration step ``dt``, and the step index ``i``.
+ """
+ raise NotImplementedError('The "update" function must be implemented by subclasses.')
+
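A hedged sketch of the connection formats accepted by the relocated SynConn (the Lif groups and the bp.dyn.projections path are assumptions, and the update body is illustrative only, since the base class leaves it abstract):

import numpy as np
import brainpy as bp
import brainpy.math as bm


class MySyn(bp.dyn.projections.SynConn):
  def update(self, pre_activity):
    # illustrative: project pre-synaptic activity through the dense connection matrix
    mask = bm.asarray(self.conn.require('conn_mat'), dtype=bm.float_)
    return pre_activity @ mask


pre, post = bp.dyn.Lif(10), bp.dyn.Lif(20)
syn1 = MySyn(pre, post, conn=bp.connect.FixedProb(0.2))                        # connector
syn2 = MySyn(pre, post, conn=np.random.rand(10, 20) < 0.2)                     # dense matrix
syn3 = MySyn(pre, post, conn={'i': np.array([0, 1]), 'j': np.array([3, 4])})   # COO dict
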
diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py
index 9ce83e144..8e91ecd11 100644
--- a/brainpy/_src/dyn/rates/populations.py
+++ b/brainpy/_src/dyn/rates/populations.py
@@ -99,7 +99,7 @@ def __init__(
mode: bm.Mode = None,
input_var: bool = True,
):
- super(FHN, self).__init__(size=size,
+ super().__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
index bf14cbae0..ac84ed797 100644
--- a/brainpy/_src/dynold/synapses/base.py
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -1,14 +1,13 @@
from typing import Union, Dict, Callable, Optional, Tuple
import jax
-import numpy as np
from brainpy import math as bm
-from brainpy._src.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
+from brainpy._src.connect import TwoEndConnector, One2One, All2All
from brainpy._src.dnn import linear
from brainpy._src.dyn import projections
-from brainpy._src.dynsys import Projection, DynamicalSystem
from brainpy._src.dyn.base import NeuDyn
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, ParamDescInit, JointType,
AutoDelaySupp, BindCondData, AlignPost,
@@ -17,7 +16,6 @@
from brainpy.types import ArrayType
__all__ = [
- 'SynConn',
'_SynSTP',
'_SynOut',
'TwoEndConn',
@@ -26,97 +24,6 @@
]
-class SynConn(Projection):
- """Base class to model two-end synaptic connections.
-
- Parameters
- ----------
- pre : NeuGroup
- Pre-synaptic neuron group.
- post : NeuGroup
- Post-synaptic neuron group.
- conn : optional, ndarray, ArrayType, dict, TwoEndConnector
- The connection method between pre- and post-synaptic groups.
- name : str, optional
- The name of the dynamic system.
- """
-
- def __init__(
- self,
- pre: DynamicalSystem,
- post: DynamicalSystem,
- conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None,
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
- ):
- super().__init__(name=name, mode=mode)
-
- # pre or post neuron group
- # ------------------------
- if not isinstance(pre, DynamicalSystem):
- raise TypeError('"pre" must be an instance of DynamicalSystem.')
- if not isinstance(post, DynamicalSystem):
- raise TypeError('"post" must be an instance of DynamicalSystem.')
- self.pre = pre
- self.post = post
-
- # connectivity
- # ------------
- if isinstance(conn, TwoEndConnector):
- self.conn = conn(pre.size, post.size)
- elif isinstance(conn, (bm.Array, np.ndarray, jax.Array)):
- if (pre.num, post.num) != conn.shape:
- raise ValueError(f'"conn" is provided as a matrix, and it is expected '
- f'to be an array with shape of (pre.num, post.num) = '
- f'{(pre.num, post.num)}, however we got {conn.shape}')
- self.conn = MatConn(conn_mat=conn)
- elif isinstance(conn, dict):
- if not ('i' in conn and 'j' in conn):
- raise ValueError(f'"conn" is provided as a dict, and it is expected to '
- f'be a dictionary with "i" and "j" specification, '
- f'however we got {conn}')
- self.conn = IJConn(i=conn['i'], j=conn['j'])
- elif isinstance(conn, str):
- self.conn = conn
- elif conn is None:
- self.conn = None
- else:
- raise ValueError(f'Unknown "conn" type: {conn}')
-
- def __repr__(self):
- names = self.__class__.__name__
- return (f'{names}(name={self.name}, mode={self.mode}, \n'
- f'{" " * len(names)} pre={self.pre}, \n'
- f'{" " * len(names)} post={self.post})')
-
- def check_pre_attrs(self, *attrs):
- """Check whether pre group satisfies the requirement."""
- if not hasattr(self, 'pre'):
- raise ValueError('Please call __init__ function first.')
- for attr in attrs:
- if not isinstance(attr, str):
- raise TypeError(f'Must be string. But got {attr}.')
- if not hasattr(self.pre, attr):
- raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')
-
- def check_post_attrs(self, *attrs):
- """Check whether post group satisfies the requirement."""
- if not hasattr(self, 'post'):
- raise ValueError('Please call __init__ function first.')
- for attr in attrs:
- if not isinstance(attr, str):
- raise TypeError(f'Must be string. But got {attr}.')
- if not hasattr(self.post, attr):
- raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".')
-
- def update(self, *args, **kwargs):
- """The function to specify the updating rule.
-
- Assume any dynamical system depends on the shared variables (`sha`),
- like time variable ``t``, the step precision ``dt``, and the time step `i`.
- """
- raise NotImplementedError('Must implement "update" function by subclass self.')
-
class _SynapseComponent(DynamicalSystem):
"""Base class for modeling synaptic components,
@@ -124,7 +31,7 @@ class _SynapseComponent(DynamicalSystem):
synaptic long-term plasticity, and others. """
'''Master of this component.'''
- master: SynConn
+ master: projections.SynConn
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -145,9 +52,9 @@ def isregistered(self, val: bool):
def reset_state(self, batch_size=None):
pass
- def register_master(self, master: SynConn):
- if not isinstance(master, SynConn):
- raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}')
+ def register_master(self, master: projections.SynConn):
+ if not isinstance(master, projections.SynConn):
+ raise TypeError(f'master must be instance of {projections.SynConn.__name__}, but we got {type(master)}')
if self.isregistered:
raise ValueError(f'master has been registered, but we got another master going to be registered.')
if hasattr(self, 'master') and self.master != master:
@@ -185,7 +92,7 @@ def __init__(
f'But we got {type(target_var)}')
self.target_var: Optional[bm.Variable] = target_var
- def register_master(self, master: SynConn):
+ def register_master(self, master: projections.SynConn):
super().register_master(master)
# initialize target variable to output
@@ -220,7 +127,7 @@ def clone(self):
return _NullSynOut()
-class TwoEndConn(SynConn):
+class TwoEndConn(projections.SynConn):
"""Base class to model synaptic connections.
Parameters
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 8a096ddf9..861b679a0 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -23,7 +23,7 @@
'DynSysGroup', 'Network', 'Sequential',
# category
- 'Dynamic', 'Projection', 'AnnLayer',
+ 'Dynamic', 'Projection',
]
SLICE_VARS = 'slice_vars'
@@ -322,26 +322,6 @@ def __init__(
self.children = bm.node_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict))
- def update(self):
- """Update function of a container.
-
- In this update function, the update functions in children systems are
- iteratively called.
- """
- for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
- node()
-
- def clear_input(self):
- """Clear inputs in the children classes."""
- for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
- node.clear_input()
-
-
-class Network(DynSysGroup):
- """A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network.
- """
-
- @not_pass_shared
def update(self, *args, **kwargs):
"""Step function of a network.
@@ -365,18 +345,30 @@ def update(self, *args, **kwargs):
def reset_state(self, batch_size=None):
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
- # reset dynamics
- for node in nodes.subset(Dynamic).values():
- node.reset_state(batch_size)
-
# reset projections
for node in nodes.subset(Projection).values():
node.reset_state(batch_size)
+ # reset dynamics
+ for node in nodes.subset(Dynamic).values():
+ node.reset_state(batch_size)
+
# reset other types of nodes, including delays, ...
for node in nodes.not_subset(Dynamic).not_subset(Projection).values():
node.reset_state(batch_size)
+ def clear_input(self):
+ """Clear inputs in the children classes."""
+ nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
+ for node in nodes.values():
+ node.clear_input()
+
+
+class Network(DynSysGroup):
+ """A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network.
+ """
+ pass
+
class Sequential(DynamicalSystem, AutoDelaySupp):
"""A sequential `input-output` module.
@@ -626,13 +618,6 @@ def __getitem__(self, item):
return DynView(target=self, index=item)
-class AnnLayer(DynamicalSystem):
- """Base class for a layer of artificial neural network."""
-
- def reset_state(self, *args, **kwargs):
- pass
-
-
class DynView(Dynamic):
"""DSView, an object used to get a view of a dynamical system instance.
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 547529076..8447e32e7 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -60,30 +60,30 @@ class ParamDescInit(object):
"""Delayed initialization for parameter describers.
"""
- def __init__(self, cls: type, *args, **kwargs):
+ def __init__(self, cls: type, *desc_tuple, **desc_dict):
self.cls = cls
# arguments
- self.args = args
- self.kwargs = kwargs
+ self.args = desc_tuple
+ self.kwargs = desc_dict
# identifier
if isinstance(cls, _JointGenericAlias):
name = str(cls)
- repr_kwargs = {k: v for k, v in kwargs.items()}
+ repr_kwargs = {k: v for k, v in desc_dict.items()}
else:
assert isinstance(cls, type)
if issubclass(cls, ParamDesc) and (cls.not_desc_params is not None):
- repr_kwargs = {k: v for k, v in kwargs.items() if k not in cls.not_desc_params}
+ repr_kwargs = {k: v for k, v in desc_dict.items() if k not in cls.not_desc_params}
else:
- repr_kwargs = {k: v for k, v in kwargs.items()}
+ repr_kwargs = {k: v for k, v in desc_dict.items()}
name = cls.__name__
for k in tuple(repr_kwargs.keys()):
if isinstance(repr_kwargs[k], bm.Variable):
repr_kwargs[k] = id(repr_kwargs[k])
repr_args = tools.repr_dict(repr_kwargs)
- if len(args):
- repr_args = f"{', '.join([repr(arg) for arg in args])}, {repr_args}"
+ if len(desc_tuple):
+ repr_args = f"{', '.join([repr(arg) for arg in desc_tuple])}, {repr_args}"
self._identifier = f'{name}({repr_args})'
def __call__(self, *args, **kwargs):
@@ -197,43 +197,53 @@ def format_elements(self, child_type: type, *children_as_tuple, **children_as_di
res[k] = v
return res
+ def add_elem(self, **elements):
+ """Add new elements.
+
+ >>> obj = Container()
+ >>> obj.add_elem(l1=bp.dnn.Dense(3, 4))
+
+ Args:
+ elements: children objects.
+ """
+ self.check_hierarchies(type(self), **elements)
+ self.children.update(self.format_elements(object, **elements))
+
class TreeNode(MixIn):
"""Tree node. """
master_type: type
- @staticmethod
- def check_hierarchies(root, *leaves, **named_leaves):
+ def check_hierarchies(self, root, *leaves, **named_leaves):
global DynamicalSystem
if DynamicalSystem is None:
from brainpy._src.dynsys import DynamicalSystem
for leaf in leaves:
if isinstance(leaf, DynamicalSystem):
- TreeNode.check_hierarchy(root, leaf)
+ self.check_hierarchy(root, leaf)
elif isinstance(leaf, (list, tuple)):
- TreeNode.check_hierarchies(root, *leaf)
+ self.check_hierarchies(root, *leaf)
elif isinstance(leaf, dict):
- TreeNode.check_hierarchies(root, **leaf)
+ self.check_hierarchies(root, **leaf)
else:
raise ValueError(f'Do not support {type(leaf)}.')
for leaf in named_leaves.values():
if not isinstance(leaf, DynamicalSystem):
raise ValueError(f'Do not support {type(leaf)}. Must be instance of {DynamicalSystem.__name__}')
- TreeNode.check_hierarchy(root, leaf)
+ self.check_hierarchy(root, leaf)
- @staticmethod
- def check_hierarchy(root, leaf):
+ def check_hierarchy(self, root, leaf):
if hasattr(leaf, 'master_type'):
master_type = leaf.master_type
else:
- raise ValueError('Child class should define "root_type" to '
+ raise ValueError('Child class should define "master_type" to '
'specify the type of the root node. '
f'But we did not found it in {leaf}')
if not issubclass(root, master_type):
raise TypeError(f'Type does not match. {leaf} requires a master with type '
- f'of {leaf.master_type}, but the master now is {leaf}.')
+ f'of {leaf.master_type}, but the master now is {root}.')
class DelayRegister(MixIn):
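check_hierarchy and check_hierarchies are now instance methods, and the mismatch error reports the offending root type. A hedged sketch of the contract (Soma and MyChannel are illustrative; bp.mixin.TreeNode is assumed to be the public alias):

import brainpy as bp


class Soma(bp.DynamicalSystem):
  pass


class MyChannel(bp.DynamicalSystem, bp.mixin.TreeNode):
  master_type = Soma   # this node may only be hosted by a Soma


ch = MyChannel()
ch.check_hierarchy(Soma, ch)              # passes: Soma satisfies master_type
# ch.check_hierarchy(bp.DynSysGroup, ch)  # would raise TypeError naming DynSysGroup as the root
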
diff --git a/brainpy/dnn/others.py b/brainpy/dnn/others.py
index 958c155a1..7bd47b928 100644
--- a/brainpy/dnn/others.py
+++ b/brainpy/dnn/others.py
@@ -1,5 +1,8 @@
+from brainpy._src.dnn.base import (
+ Layer as Layer,
+)
from brainpy._src.dnn.dropout import (
Dropout as Dropout,
)
diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py
index 15dde3d57..a09617988 100644
--- a/brainpy/dyn/projections.py
+++ b/brainpy/dyn/projections.py
@@ -5,6 +5,10 @@
ProjAlignPre as ProjAlignPre,
)
+from brainpy._src.dyn.projections.conn import (
+ SynConn as SynConn,
+)
+
from brainpy._src.dyn.projections.others import (
PoissonInput as PoissonInput,
)
diff --git a/brainpy/synapses.py b/brainpy/synapses.py
index d07fb1954..266ebf280 100644
--- a/brainpy/synapses.py
+++ b/brainpy/synapses.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from brainpy._src.dynold.synapses.base import (
- SynConn as SynConn,
_SynSTP as SynSTP,
_SynOut as SynOut,
TwoEndConn as TwoEndConn,
diff --git a/docs/index.rst b/docs/index.rst
index bf1a38560..cf5b06e87 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -114,9 +114,10 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
apis/auto/changelog.rst
+The following APIs will be no longer supported.
+
.. toctree::
:maxdepth: 1
- :caption: Deprecated APIs
apis/channels.rst
apis/neurons.rst
diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py
index 5adb2fd4a..06b435595 100644
--- a/examples/dynamics_simulation/hh_model.py
+++ b/examples/dynamics_simulation/hh_model.py
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
+
import numpy as np
import brainpy as bp
@@ -8,15 +9,36 @@
bm.set_host_device_count(20)
-class HH(bp.CondNeuGroup):
+class HH(bp.dyn.CondNeuGroup):
def __init__(self, size):
- super(HH, self).__init__(size, keep_size=True)
+ super().__init__(size, keep_size=True)
self.INa = bp.channels.INa_HH1952(size, keep_size=True)
self.IK = bp.channels.IK_HH1952(size, keep_size=True)
self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03, keep_size=True)
+class HHv2(bp.dyn.CondNeuGroupLTC):
+ def __init__(self, size):
+ super().__init__(size, keep_size=True)
+
+ self.Na = bp.dyn.SodiumFixed(size, E=50.)
+ self.Na.add(ina=bp.dyn.INa_HH1952v2(size, keep_size=True))
+
+ self.K = bp.dyn.PotassiumFixed(size, E=50.)
+ self.K.add(ik=bp.dyn.IK_HH1952v2(size, keep_size=True))
+
+ self.IL = bp.dyn.IL(size, E=-54.387, g_max=0.03, keep_size=True)
+
+ self.KNa = bp.dyn.mixs(self.Na, self.K)
+ self.KNa.add()
+
+
# hh = HH(1)
# I, length = bp.inputs.section_input(values=[0, 5, 0],
# durations=[100, 500, 100],
diff --git a/examples/dynamics_training/Song_2016_EI_RNN.py b/examples/dynamics_training/Song_2016_EI_RNN.py
index 0df5f9409..e4a19ba7b 100644
--- a/examples/dynamics_training/Song_2016_EI_RNN.py
+++ b/examples/dynamics_training/Song_2016_EI_RNN.py
@@ -72,7 +72,6 @@ def cell(self, x, h):
def readout(self, h):
return h @ self.w_ro + self.b_ro
- @bp.not_pass_shared
def update(self, x):
self.h.value = self.cell(x, self.h)
self.o.value = self.readout(self.h[:, :self.e_size])
From 44df96a6aacd382006b10dfe7c996409232f0c54 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 11 Jul 2023 21:37:47 +0800
Subject: [PATCH 029/326] new api doc
---
docs/index.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/index.rst b/docs/index.rst
index cf5b06e87..fbc773668 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -114,7 +114,7 @@ general-purpose Brain Dynamics Programming (BDP). Among its key ingredients, Bra
apis/auto/changelog.rst
-The following APIs will be no longer supported.
+The following APIs will no longer be maintained in the future, but you can still use them normally.
.. toctree::
:maxdepth: 1
From fab256ab507e120865727a9259dd2a894f3730ab Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 11 Jul 2023 21:56:04 +0800
Subject: [PATCH 030/326] fix test bugs
---
brainpy/_add_deprecations.py | 10 +-
.../connect/tests/test_optimized_result.py | 382 +++++++++---------
brainpy/_src/dnn/dropout.py | 4 +-
brainpy/_src/losses/base.py | 4 +-
brainpy/_src/tests/test_mixin.py | 2 +-
brainpy/dyn/channels.py | 1 -
6 files changed, 204 insertions(+), 199 deletions(-)
diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py
index 05398c45f..0b782d3cf 100644
--- a/brainpy/_add_deprecations.py
+++ b/brainpy/_add_deprecations.py
@@ -8,7 +8,7 @@
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.integrators.sde.generic import sdeint
from brainpy._src.integrators.fde.generic import fdeint
-from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network, AnnLayer)
+from brainpy._src.dynsys import (DynamicalSystem, DynSysGroup, Sequential, Network)
from brainpy._src.dyn.base import NeuDyn, IonChaDyn
from brainpy._src.runners import DSRunner
from brainpy._src.deprecations import deprecation_getattr2
@@ -108,9 +108,9 @@
# dnn.__getattr__ = deprecation_getattr2('brainpy.dnn', dnn.__deprecations)
-layers.__deprecations = {
- 'Layer': ('brainpy.layers.Layer', 'brainpy.AnnLayer', AnnLayer),
-}
-layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations)
+# layers.__deprecations = {
+# 'Layer': ('brainpy.layers.Layer', 'brainpy.AnnLayer', AnnLayer),
+# }
+# layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations)
diff --git a/brainpy/_src/connect/tests/test_optimized_result.py b/brainpy/_src/connect/tests/test_optimized_result.py
index 7afd03136..6eb4d5f2a 100644
--- a/brainpy/_src/connect/tests/test_optimized_result.py
+++ b/brainpy/_src/connect/tests/test_optimized_result.py
@@ -4,234 +4,240 @@
import pytest
import unittest
import brainpy as bp
from time import time
try:
- import pandas as pd
+ import pandas as pd
- df = pd.DataFrame(
- columns=['connector name', 'connect matrix size', 'build function', 'other parameter', 'time origin(ms)',
- 'time optimized(ms)'])
+ df = pd.DataFrame(
+ columns=['connector name', 'connect matrix size',
+ 'build function', 'other parameter',
+ 'time origin(ms)', 'time optimized(ms)'])
except (ImportError, ModuleNotFoundError):
- print('No pandas installed, skip test.')
+ pytest.skip('No pandas installed, skip test.', allow_module_level=True)
# size_same = [100, 500, 2500, 12500, 25000, 37500, 50000]
# size_same = [100, 500, 2500, 12500]
size_same = [100, 500, 2500]
+
def get_ms(value):
- return round(value * 1000, 4)
+ return round(value * 1000, 4)
-def insert_row(connector_name, connect_matrix_size, build_function, other_parameter, time_origin_used,
- time_optimized_used):
- try:
- df.loc[len(df)] = [connector_name, connect_matrix_size, build_function, other_parameter, time_origin_used, time_optimized_used]
- except (NameError, UnboundLocalError):
- print('No pandas installed, skip test.')
+def insert_row(connector_name, connect_matrix_size,
+ build_function, other_parameter,
+ time_origin_used, time_optimized_used):
+ try:
+ df.loc[len(df)] = [connector_name, connect_matrix_size,
+ build_function, other_parameter,
+ time_origin_used, time_optimized_used]
+ except (NameError, UnboundLocalError):
+ print('No pandas installed, skip test.')
def test_GaussianProb1():
- conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)
- for size in size_same:
- conn(pre_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
- insert_row('GaussianProb',
- f'{size}x{size}',
- 'build_mat',
- 'sigma=1 / include_self=False',
- time_origin,
- time_optimized)
+ conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=1 / include_self=False',
+ time_origin,
+ time_optimized)
def test_GaussianProb2():
- conn = bp.connect.GaussianProb(sigma=4, seed=123)
- for size in size_same:
- conn(pre_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
- insert_row('GaussianProb',
- f'{size}x{size}',
- 'build_mat',
- 'sigma=4',
- time_origin,
- time_optimized)
+ conn = bp.connect.GaussianProb(sigma=4, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=4',
+ time_origin,
+ time_optimized)
def test_GaussianProb3():
- conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)
- for size in size_same:
- conn(pre_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
- insert_row('GaussianProb',
- f'{size}x{size}',
- 'build_mat',
- 'sigma=4 / periodic_boundary=True',
- time_origin,
- time_optimized)
+ conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=4 / periodic_boundary=True',
+ time_origin,
+ time_optimized)
def testGaussianProb4():
- conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)
- for size in size_same:
- conn(pre_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- print()
- print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
- insert_row('GaussianProb',
- f'{size}x{size}',
- 'build_mat',
- 'sigma=4 / periodic_boundary=True',
- time_origin,
- time_optimized)
+ conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123)
+ for size in size_same:
+ conn(pre_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ print()
+ print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}')
+ insert_row('GaussianProb',
+ f'{size}x{size}',
+ 'build_mat',
+ 'sigma=4 / periodic_boundary=True',
+ time_origin,
+ time_optimized)
def test_ScaleFreeBA():
- conn = bp.connect.ScaleFreeBA(m=2, seed=123)
- for size in size_same:
- conn(pre_size=size, post_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- insert_row('ScaleFreeBA',
- f'{size}x{size}',
- 'build_mat',
- 'm=2',
- time_origin,
- time_optimized)
+ conn = bp.connect.ScaleFreeBA(m=2, seed=123)
+ for size in size_same:
+ conn(pre_size=size, post_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ insert_row('ScaleFreeBA',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=2',
+ time_origin,
+ time_optimized)
def test_ScaleFreeBADual():
- conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4, seed=123)
- for size in size_same:
- conn(pre_size=size, post_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- insert_row('ScaleFreeBADual',
- f'{size}x{size}',
- 'build_mat',
- 'm1=2 / m2=3 / p=0.4',
- time_origin,
- time_optimized)
+ conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4, seed=123)
+ for size in size_same:
+ conn(pre_size=size, post_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ insert_row('ScaleFreeBADual',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm1=2 / m2=3 / p=0.4',
+ time_origin,
+ time_optimized)
def test_PowerLaw():
- conn = bp.connect.PowerLaw(m=3, p=0.4, seed=123)
- for size in size_same:
- conn(pre_size=size, post_size=size)
- mat = conn.build_mat(isOptimized=True)
- time0 = time()
- mat1 = conn.build_mat(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- mat2 = conn.build_mat(isOptimized=False)
- time0 = time()
- mat2 = conn.build_mat(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- assert bp.math.array_equal(mat1, mat2)
- insert_row('PowerLaw',
- f'{size}x{size}',
- 'build_mat',
- 'm=3 / p=0.4',
- time_origin,
- time_optimized)
+ conn = bp.connect.PowerLaw(m=3, p=0.4, seed=123)
+ for size in size_same:
+ conn(pre_size=size, post_size=size)
+ mat = conn.build_mat(isOptimized=True)
+ time0 = time()
+ mat1 = conn.build_mat(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ mat2 = conn.build_mat(isOptimized=False)
+ time0 = time()
+ mat2 = conn.build_mat(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ assert bp.math.array_equal(mat1, mat2)
+ insert_row('PowerLaw',
+ f'{size}x{size}',
+ 'build_mat',
+ 'm=3 / p=0.4',
+ time_origin,
+ time_optimized)
def test_ProbDist():
- conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=123, include_self=True)
- # for size in [1000, (100, 20), (4, 20, 20), (4, 3, 8, 5)]:
- for size in [10000]:
- conn(pre_size=size, post_size=size)
- pre_ids1, post_ids1 = conn.build_coo(isOptimized=True)
- time0 = time()
- pre_ids1, post_ids1 = conn.build_coo(isOptimized=True)
- time_optimized = get_ms(time() - time0)
-
- pre_ids2, post_ids2 = conn.build_coo(isOptimized=False)
- time0 = time()
- pre_ids2, post_ids2 = conn.build_coo(isOptimized=False)
- time_origin = get_ms(time() - time0)
-
- # assert (bp.math.array_equal(pre_ids1, pre_ids2) and bp.math.array_equal(post_ids1, post_ids2))
- print()
- print(f'time origin: {time_origin}\ntime optimized: {time_optimized}')
- insert_row('ProbDist',
- {size},
- 'build_coo',
- 'dist=1 / prob=0.5 / pre_ratio=0.3 / include_self=True',
- time_origin,
- time_optimized)
+ conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=123, include_self=True)
+ # for size in [1000, (100, 20), (4, 20, 20), (4, 3, 8, 5)]:
+ for size in [10000]:
+ conn(pre_size=size, post_size=size)
+ pre_ids1, post_ids1 = conn.build_coo(isOptimized=True)
+ time0 = time()
+ pre_ids1, post_ids1 = conn.build_coo(isOptimized=True)
+ time_optimized = get_ms(time() - time0)
+
+ pre_ids2, post_ids2 = conn.build_coo(isOptimized=False)
+ time0 = time()
+ pre_ids2, post_ids2 = conn.build_coo(isOptimized=False)
+ time_origin = get_ms(time() - time0)
+
+ # assert (bp.math.array_equal(pre_ids1, pre_ids2) and bp.math.array_equal(post_ids1, post_ids2))
+ print()
+ print(f'time origin: {time_origin}\ntime optimized: {time_optimized}')
+ insert_row('ProbDist',
+ f'{size}x{size}',
+ 'build_coo',
+ 'dist=1 / prob=0.5 / pre_ratio=0.3 / include_self=True',
+ time_origin,
+ time_optimized)
def test_save():
- try:
- df.to_csv('opt_time_compare' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv',
- index=False)
- except (NameError, UnboundLocalError):
- print('No pandas installed, skip test.')
\ No newline at end of file
+ try:
+ df.to_csv('opt_time_compare' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv',
+ index=False)
+ except (NameError, UnboundLocalError):
+ print('No pandas installed, skip test.')
diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py
index 0ec7ad494..6bd8bde7a 100644
--- a/brainpy/_src/dnn/dropout.py
+++ b/brainpy/_src/dnn/dropout.py
@@ -4,14 +4,14 @@
from brainpy._src.context import share
from brainpy import math as bm, check
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'Dropout'
]
-class Dropout(AnnLayer):
+class Dropout(Layer):
"""A layer that stochastically ignores a subset of inputs each training step.
In training, to compensate for the fraction of input values dropped (`rate`),
diff --git a/brainpy/_src/losses/base.py b/brainpy/_src/losses/base.py
index e1cfecf28..a01e2aee8 100644
--- a/brainpy/_src/losses/base.py
+++ b/brainpy/_src/losses/base.py
@@ -1,6 +1,6 @@
from typing import Optional
-from brainpy._src.dynsys import AnnLayer
+from brainpy._src.dnn.base import Layer
__all__ = [
'Loss',
@@ -8,7 +8,7 @@
]
-class Loss(AnnLayer):
+class Loss(Layer):
reduction: str
def __init__(self, reduction: str = 'mean') -> None:
diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py
index 1352d47b7..1544a1f33 100644
--- a/brainpy/_src/tests/test_mixin.py
+++ b/brainpy/_src/tests/test_mixin.py
@@ -18,7 +18,7 @@ def test2(self):
class TestJointType(unittest.TestCase):
def test1(self):
T = bp.mixin.JointType[bp.DynamicalSystem]
- self.assertTrue(isinstance(bp.AnnLayer(), T))
+ self.assertTrue(isinstance(bp.dnn.Layer(), T))
T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc]
self.assertTrue(isinstance(bp.dyn.Expon(1), T))
diff --git a/brainpy/dyn/channels.py b/brainpy/dyn/channels.py
index eff433df8..03d8e979f 100644
--- a/brainpy/dyn/channels.py
+++ b/brainpy/dyn/channels.py
@@ -39,7 +39,6 @@
from brainpy._src.dyn.channels.hyperpolarization_activated import (
- IhChannel,
Ih_HM1992,
Ih_De1996,
)
From dddcd92d5dcd72235f12172f91e9c1ef27521e66 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 11 Jul 2023 22:22:19 +0800
Subject: [PATCH 031/326] fix doc
---
docs/auto_generater.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index 3bca449e7..081c20821 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -510,7 +510,6 @@ def generate_brainpy_docs():
'Network',
'Dynamic',
'Projection',
- 'AnnLayer',
],
'Simulating Dynamical System': ['DSRunner'],
'Training Dynamical System': ['DSTrainer',
From fc69e4b84e2f506bf70bef51c1aefbd75e698b62 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 12 Jul 2023 22:47:14 +0800
Subject: [PATCH 032/326] Add a new implementation of the dual exponential synapse
 model (DualExponV2) that supports post-synaptic alignment (AlignPost)
---
brainpy/_src/dyn/synapses/abstract_models.py | 91 ++++++++++++++++++--
brainpy/dyn/synapses.py | 1 +
2 files changed, 84 insertions(+), 8 deletions(-)
diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index cd8162f58..81cf954d5 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -15,6 +15,7 @@
'Delta',
'Expon',
'DualExpon',
+ 'DualExponV2',
'Alpha',
'NMDA',
'STD',
@@ -154,11 +155,9 @@ def reset_state(self, batch_size=None):
self.g = self.init_variable(bm.zeros, batch_size)
def update(self, x=None):
- t = share.load('t')
- dt = share.load('dt')
- self.g.value = self.integral(self.g.value, t, dt)
+ self.g.value = self.integral(self.g.value, share['t'], share['dt'])
if x is not None:
- self.g.value += x
+ self.add_current(x)
return self.g.value
def add_current(self, x):
@@ -250,11 +249,8 @@ def dg(self, g, t, h):
return -g / self.tau_decay + h
def update(self, x):
- t = share.load('t')
- dt = share.load('dt')
-
# update synaptic variables
- self.g.value, self.h.value = self.integral(self.g.value, self.h.value, t, dt=dt)
+ self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt'])
self.h += x
return self.g.value
@@ -265,6 +261,85 @@ def return_info(self):
DualExpon.__doc__ = DualExpon.__doc__ % (pneu_doc,)
+class DualExponV2(SynDyn, AlignPost):
+ r"""Dual exponential synapse model.
+
+ The dual exponential synapse model [1]_, also known as the *difference of two exponentials* model,
+ is given by:
+
+ .. math::
+
+ g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{
+ \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right)
+ -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right)
+
+ where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2`
+ is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic
+ spike, :math:`g_{\mathrm{max}}` is the maximal conductance.
+
+ .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
+ "The Synapse." Principles of Computational Modelling in Neuroscience.
+ Cambridge: Cambridge UP, 2011. 172-95. Print.
+ .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational
+ Modeling Methods for Neuroscientists.
+
+ Args:
+ tau_decay: float, ArrayType, Callable. The time constant of the synaptic decay phase. [ms]
+ tau_rise: float, ArrayType, Callable. The time constant of the synaptic rise phase. [ms]
+ %s
+ """
+
+ def __init__(
+ self,
+ size: Union[int, Sequence[int]],
+ keep_size: bool = False,
+ sharding: Optional[Sequence[str]] = None,
+ method: str = 'exp_auto',
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+
+ # synapse parameters
+ tau_decay: Union[float, ArrayType, Callable] = 10.0,
+ tau_rise: Union[float, ArrayType, Callable] = 1.,
+ ):
+ super().__init__(name=name,
+ mode=mode,
+ size=size,
+ keep_size=keep_size,
+ sharding=sharding)
+
+ # parameters
+ self.tau_rise = self.init_param(tau_rise)
+ self.tau_decay = self.init_param(tau_decay)
+ self.coeff = self.tau_rise * self.tau_decay / (self.tau_decay - self.tau_rise)
+
+ # integrator
+ self.integral = odeint(lambda g, t, tau: -g / tau, method=method)
+
+ self.reset_state(self.mode)
+
+ def reset_state(self, batch_size=None):
+ self.g_rise = self.init_variable(bm.zeros, batch_size)
+ self.g_decay = self.init_variable(bm.zeros, batch_size)
+
+ def update(self, x=None):
+ self.g_rise.value = self.integral(self.g_rise.value, share['t'], self.tau_rise, share['dt'])
+ self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt'])
+ if x is not None:
+ self.add_current(x)
+ return self.coeff * (self.g_decay - self.g_rise)
+
+ def add_current(self, inp):
+ self.g_rise += inp
+ self.g_decay += inp
+
+ def return_info(self):
+ return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)
+
+
+DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)
+
+
class Alpha(DualExpon):
r"""Alpha synapse model.
diff --git a/brainpy/dyn/synapses.py b/brainpy/dyn/synapses.py
index 77ab86632..785e3f967 100644
--- a/brainpy/dyn/synapses.py
+++ b/brainpy/dyn/synapses.py
@@ -3,6 +3,7 @@
Delta,
Expon,
DualExpon,
+ DualExponV2,
NMDA,
STD,
STP,
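A hedged sketch of driving the new DualExponV2 directly (the bp.dyn.DualExponV2 alias, the bp.share.save call, and the sizes are assumptions; in practice the synapse sits inside a projection and receives increments through the AlignPost interface):

import brainpy as bp
import brainpy.math as bm

bp.share.save(t=0., dt=0.1)   # provide the shared time and step size
syn = bp.dyn.DualExponV2(size=10, tau_rise=1., tau_decay=10.)
spikes = bm.asarray(bm.random.rand(10) < 0.2, dtype=bm.float_)

g = syn(spikes)   # decay both traces, add the spike-triggered increment, return the conductance
g = syn()         # a step without pre-synaptic input: pure decay
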
From 8ce45b382c4d2320c33616e5b02a49742bdc71bc Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 12:34:01 +0800
Subject: [PATCH 033/326] Enable tests on pull requests
---
.github/workflows/CI-models.yml | 3 +++
.github/workflows/CI.yml | 3 +++
2 files changed, 6 insertions(+)
diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml
index 2fef6aad2..1b416ccc4 100644
--- a/.github/workflows/CI-models.yml
+++ b/.github/workflows/CI-models.yml
@@ -4,6 +4,9 @@ on:
push:
branches:
- '**' # matches every branch
+ pull_request:
+ branches:
+ - '**' # matches every branch
#
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 801013f8b..a1ed29125 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -7,6 +7,9 @@ on:
push:
branches:
- '**' # matches every branch
+ pull_request:
+ branches:
+ - '**' # matches every branch
#on:
# push:
From 4a6add60d7703adc1f392a8011f22b925b4a9577 Mon Sep 17 00:00:00 2001
From: GYF <1337838189@qq.com>
Date: Thu, 13 Jul 2023 14:50:22 +0800
Subject: [PATCH 034/326] Add random.seed
---
brainpy/_src/dnn/tests/test_activation.py | 27 +++++++++++++++++++
brainpy/_src/dnn/tests/test_conv_layers.py | 23 ++++++++++++----
brainpy/_src/dnn/tests/test_function.py | 5 +++-
brainpy/_src/dnn/tests/test_linear.py | 13 +++++++++
brainpy/_src/dnn/tests/test_normalization.py | 6 +++++
brainpy/_src/dnn/tests/test_pooling_layers.py | 17 ++++++++++++
6 files changed, 85 insertions(+), 6 deletions(-)
diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py
index 2915b0f35..ab7e20300 100644
--- a/brainpy/_src/dnn/tests/test_activation.py
+++ b/brainpy/_src/dnn/tests/test_activation.py
@@ -22,6 +22,7 @@ def test_Threshold(self, inplace):
inplace=[True, False]
)
def test_ReLU(self, inplace):
+ bm.random.seed()
ReLU_layer = bp.dnn.ReLU(inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -33,6 +34,7 @@ def test_ReLU(self, inplace):
inplace=[True, False]
)
def test_RReLU(self, inplace):
+ bm.random.seed()
RReLU_layer = bp.dnn.RReLU(lower=0, upper=1, inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -44,6 +46,7 @@ def test_RReLU(self, inplace):
inplace=[True, False]
)
def test_Hardtanh(self, inplace):
+ bm.random.seed()
Hardtanh_layer = bp.dnn.Hardtanh(min_val=0, max_val=1, inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -55,6 +58,7 @@ def test_Hardtanh(self, inplace):
inplace=[True, False]
)
def test_ReLU6(self, inplace):
+ bm.random.seed()
ReLU6_layer = bp.dnn.ReLU6(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -63,6 +67,7 @@ def test_ReLU6(self, inplace):
output = ReLU6_layer(input)
def test_Sigmoid(self):
+ bm.random.seed()
Sigmoid_layer = bp.dnn.Sigmoid()
input = bm.random.randn(2)
output = Sigmoid_layer(input)
@@ -71,6 +76,7 @@ def test_Sigmoid(self):
inplace=[True, False]
)
def test_Hardsigmoid(self, inplace):
+ bm.random.seed()
Hardsigmoid_layer = bp.dnn.Hardsigmoid(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -79,6 +85,7 @@ def test_Hardsigmoid(self, inplace):
output = Hardsigmoid_layer(input)
def test_Tanh(self):
+ bm.random.seed()
Tanh_layer = bp.dnn.Tanh()
input = bm.random.randn(2)
output = Tanh_layer(input)
@@ -87,6 +94,7 @@ def test_Tanh(self):
inplace=[True, False]
)
def test_SiLU(self, inplace):
+ bm.random.seed()
SiLU_layer = bp.dnn.SiLU(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -98,6 +106,7 @@ def test_SiLU(self, inplace):
inplace=[True, False]
)
def test_Mish(self, inplace):
+ bm.random.seed()
Mish_layer = bp.dnn.Mish(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -109,6 +118,7 @@ def test_Mish(self, inplace):
inplace=[True, False]
)
def test_Hardswish(self, inplace):
+ bm.random.seed()
Hardswish_layer = bp.dnn.Hardswish(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -120,6 +130,7 @@ def test_Hardswish(self, inplace):
inplace=[True, False]
)
def test_ELU(self, inplace):
+ bm.random.seed()
ELU_layer = bp.dnn.ELU(alpha=0.5, inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -131,6 +142,7 @@ def test_ELU(self, inplace):
inplace=[True, False]
)
def test_CELU(self, inplace):
+ bm.random.seed()
CELU_layer = bp.dnn.CELU(alpha=0.5, inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -142,6 +154,7 @@ def test_CELU(self, inplace):
inplace=[True, False]
)
def test_SELU(self, inplace):
+ bm.random.seed()
SELU_layer = bp.dnn.SELU(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -150,6 +163,7 @@ def test_SELU(self, inplace):
output = SELU_layer(input)
def test_GLU(self):
+ bm.random.seed()
GLU_layer = bp.dnn.GLU()
input = bm.random.randn(4, 2)
output = GLU_layer(input)
@@ -158,11 +172,13 @@ def test_GLU(self):
approximate=['tanh', 'none']
)
def test_GELU(self, approximate):
+ bm.random.seed()
GELU_layer = bp.dnn.GELU()
input = bm.random.randn(2)
output = GELU_layer(input)
def test_Hardshrink(self):
+ bm.random.seed()
Hardshrink_layer = bp.dnn.Hardshrink(lambd=1)
input = bm.random.randn(2)
output = Hardshrink_layer(input)
@@ -171,6 +187,7 @@ def test_Hardshrink(self):
inplace=[True, False]
)
def test_LeakyReLU(self, inplace):
+ bm.random.seed()
LeakyReLU_layer = bp.dnn.LeakyReLU(inplace=inplace)
input = bm.random.randn(2)
if inplace == True:
@@ -179,6 +196,7 @@ def test_LeakyReLU(self, inplace):
output = LeakyReLU_layer(input)
def test_LogSigmoid(self):
+ bm.random.seed()
LogSigmoid_layer = bp.dnn.LogSigmoid()
input = bm.random.randn(2)
output = LogSigmoid_layer(input)
@@ -188,46 +206,55 @@ def test_LogSigmoid(self):
threshold=[20, 21, 22]
)
def test_Softplus(self, beta, threshold):
+ bm.random.seed()
Softplus_layer = bp.dnn.Softplus(beta=beta, threshold=threshold)
input = bm.random.randn(2)
output = Softplus_layer(input)
def test_Softshrink(self):
+ bm.random.seed()
Softshrink_layer = bp.dnn.Softshrink(lambd=1)
input = bm.random.randn(2)
output = Softshrink_layer(input)
def test_PReLU(self):
+ bm.random.seed()
PReLU_layer = bp.dnn.PReLU(num_parameters=2, init=0.5)
input = bm.random.randn(2)
output = PReLU_layer(input)
def test_Softsign(self):
+ bm.random.seed()
Softsign_layer = bp.dnn.Softsign()
input = bm.random.randn(2)
output = Softsign_layer(input)
def test_Tanhshrink(self):
+ bm.random.seed()
Tanhshrink_layer = bp.dnn.Tanhshrink()
input = bm.random.randn(2)
output = Tanhshrink_layer(input)
def test_Softmin(self):
+ bm.random.seed()
Softmin_layer = bp.dnn.Softmin(dim=2)
input = bm.random.randn(2, 3, 4)
output = Softmin_layer(input)
def test_Softmax(self):
+ bm.random.seed()
Softmax_layer = bp.dnn.Softmax(dim=2)
input = bm.random.randn(2, 3, 4)
output = Softmax_layer(input)
def test_Softmax2d(self):
+ bm.random.seed()
Softmax2d_layer = bp.dnn.Softmax2d()
input = bm.random.randn(2, 3, 12, 13)
output = Softmax2d_layer(input)
def test_LogSoftmax(self):
+ bm.random.seed()
LogSoftmax_layer = bp.dnn.LogSoftmax(dim=2)
input = bm.random.randn(2, 3, 4)
output = LogSoftmax_layer(input)
diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py
index 828b06496..b8ebc0adb 100644
--- a/brainpy/_src/dnn/tests/test_conv_layers.py
+++ b/brainpy/_src/dnn/tests/test_conv_layers.py
@@ -4,12 +4,13 @@
from absl.testing import absltest
import jax.numpy as jnp
import brainpy.math as bm
-
+from absl.testing import parameterized
import brainpy as bp
-class TestConv(bp.testing.UnitTestCase):
+class TestConv(parameterized.TestCase):
def test_Conv2D_img(self):
+ bm.random.seed()
img = jnp.zeros((2, 200, 198, 4))
for k in range(4):
x = 30 + 60 * k
@@ -28,6 +29,7 @@ def test_Conv2D_img(self):
# plt.show()
def test_conv1D(self):
+ bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
@@ -41,6 +43,7 @@ def test_conv1D(self):
# plt.show()
def test_conv2D(self):
+ bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
@@ -54,6 +57,7 @@ def test_conv2D(self):
# plt.show()
def test_conv3D(self):
+ bm.random.seed()
with bp.math.training_environment():
model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3))
input = bp.math.ones((2, 5, 5, 5, 3))
@@ -61,8 +65,9 @@ def test_conv3D(self):
print("out shape: ", out.shape)
-class TestConvTranspose1d(bp.testing.UnitTestCase):
+class TestConvTranspose1d(parameterized.TestCase):
def test_conv_transpose(self):
+ bm.random.seed()
x = bm.ones((1, 8, 3))
for use_bias in [True, False]:
conv_transpose_module = bp.layers.ConvTranspose1d(
@@ -92,6 +97,7 @@ def test_conv_transpose(self):
self.assertTrue(bm.allclose(y, correct_ans))
def test_single_input_masked_conv_transpose(self):
+ bm.random.seed()
x = jnp.ones((1, 8, 3))
m = jnp.tril(jnp.ones((3, 3, 4)))
conv_transpose_module = bp.layers.ConvTranspose1d(
@@ -120,6 +126,7 @@ def test_single_input_masked_conv_transpose(self):
self.assertTrue(bm.allclose(y, correct_ans))
def test_computation_padding_same(self):
+ bm.random.seed()
data = jnp.ones([1, 3, 1])
for use_bias in [True, False]:
net = bp.layers.ConvTranspose1d(
@@ -141,8 +148,9 @@ def test_computation_padding_same(self):
self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5))
-class TestConvTranspose2d(bp.testing.UnitTestCase):
+class TestConvTranspose2d(parameterized.TestCase):
def test_conv_transpose(self):
+ bm.random.seed()
x = bm.ones((1, 8, 8, 3))
for use_bias in [True, False]:
conv_transpose_module = bp.layers.ConvTranspose2d(
@@ -159,6 +167,7 @@ def test_conv_transpose(self):
print(y.shape)
def test_single_input_masked_conv_transpose(self):
+ bm.random.seed()
x = jnp.ones((1, 8, 8, 3))
m = jnp.tril(jnp.ones((3, 3, 3, 4)))
conv_transpose_module = bp.layers.ConvTranspose2d(
@@ -174,6 +183,7 @@ def test_single_input_masked_conv_transpose(self):
print(y.shape)
def test_computation_padding_same(self):
+ bm.random.seed()
x = bm.ones((1, 8, 8, 3))
for use_bias in [True, False]:
conv_transpose_module = bp.layers.ConvTranspose2d(
@@ -191,8 +201,9 @@ def test_computation_padding_same(self):
print(y.shape)
-class TestConvTranspose3d(bp.testing.UnitTestCase):
+class TestConvTranspose3d(parameterized.TestCase):
def test_conv_transpose(self):
+ bm.random.seed()
x = bm.ones((1, 8, 8, 8, 3))
for use_bias in [True, False]:
conv_transpose_module = bp.layers.ConvTranspose3d(
@@ -208,6 +219,7 @@ def test_conv_transpose(self):
print(y.shape)
def test_single_input_masked_conv_transpose(self):
+ bm.random.seed()
x = jnp.ones((1, 8, 8, 8, 3))
m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
conv_transpose_module = bp.layers.ConvTranspose3d(
@@ -223,6 +235,7 @@ def test_single_input_masked_conv_transpose(self):
print(y.shape)
def test_computation_padding_same(self):
+ bm.random.seed()
x = bm.ones((1, 8, 8, 8, 3))
for use_bias in [True, False]:
conv_transpose_module = bp.layers.ConvTranspose3d(
diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py
index b51efe16f..90dcae17b 100644
--- a/brainpy/_src/dnn/tests/test_function.py
+++ b/brainpy/_src/dnn/tests/test_function.py
@@ -5,12 +5,14 @@
import jax.numpy as jnp
import brainpy.math as bm
from absl.testing import absltest
+from absl.testing import parameterized
import brainpy as bp
-class TestFunction(bp.testing.UnitTestCase):
+class TestFunction(parameterized.TestCase):
def test_flatten_batching_mode(self):
+ bm.random.seed()
layer = bp.dnn.Flatten(mode=bm.BatchingMode())
input = bm.random.randn(20, 10, 10, 6)
@@ -20,6 +22,7 @@ def test_flatten_batching_mode(self):
self.assertEqual(output.shape, expected_shape)
def test_flatten_non_batching_mode(self):
+ bm.random.seed()
layer = bp.dnn.Flatten(mode=bm.NonBatchingMode())
input = bm.random.randn(10, 10, 6)
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py
index 5ce07d474..98214563a 100644
--- a/brainpy/_src/dnn/tests/test_linear.py
+++ b/brainpy/_src/dnn/tests/test_linear.py
@@ -16,6 +16,7 @@ def __init__(self, *args, **kwargs):
num_out=[20, 10, 5]
)
def test_Dense1(self, size, num_out):
+ bm.random.seed()
f = bp.dnn.Linear(10, num_out)
x = bm.random.random(size)
y = f(x)
@@ -27,12 +28,14 @@ def test_Dense1(self, size, num_out):
(5, 8, 10)],
)
def test_Identity(self, size):
+ bm.random.seed()
f = bp.dnn.Identity()
x = bm.random.random(size)
y = f(x)
self.assertTrue(y.shape == size)
def test_AllToAll1(self):
+ bm.random.seed()
with bm.environment(mode=bm.BatchingMode()):
f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True)
x = bm.random.random((8, 10))
@@ -48,6 +51,7 @@ def test_AllToAll1(self):
self.assertTrue(bm.allclose(y, expected))
def test_OneToOne(self):
+ bm.random.seed()
with bm.environment(mode=bm.BatchingMode()):
f = bp.dnn.OneToOne(10, weight=.1)
x = bm.random.random((8, 10))
@@ -70,6 +74,7 @@ def test_OneToOne(self):
]
)
def test_MaskedLinear(self, conn):
+ bm.random.seed()
bm.random.DEFAULT.seed(123)
f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123))
x = bm.random.random((16, 100))
@@ -84,6 +89,7 @@ def test_MaskedLinear(self, conn):
]
)
def test_CSRLinear(self, conn):
+ bm.random.seed()
f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal())
x = bm.random.random((16, 100))
y = f(x)
@@ -102,6 +108,7 @@ def test_CSRLinear(self, conn):
]
)
def test_EventCSRLinear(self,conn):
+ bm.random.seed()
f=bp.layers.EventCSRLinear(conn,weight=bp.init.Normal())
x = bm.random.random((16, 100))
y = f(x)
@@ -117,6 +124,7 @@ def test_EventCSRLinear(self,conn):
shape=[(), (10,), (10, 20), (10, 20, 25)]
)
def test_JitFPHomoLinear(self, prob, weight, shape):
+ bm.random.seed()
f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123)
x = bm.random.random(shape + (100,))
y = f(x)
@@ -129,6 +137,7 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
shape=[(), (10,), (10, 20), (10, 20, 25)]
)
def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
+ bm.random.seed()
f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123)
x = bm.random.random(shape + (100,))
y = f(x)
@@ -141,6 +150,7 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
shape=[(), (10,), (10, 20), (10, 20, 25)]
)
def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
+ bm.random.seed()
f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123)
x = bm.random.random(shape + (100,))
y = f(x)
@@ -152,6 +162,7 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
shape=[(), (10,), (10, 20), (10, 20, 25)]
)
def test_EventJitFPHomoLinear(self, prob, weight, shape):
+ bm.random.seed()
f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123)
y = f(bm.random.random(shape + (100,)) < 0.1)
self.assertTrue(y.shape == shape + (200,))
@@ -166,6 +177,7 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
shape=[(), (10,), (10, 20), (10, 20, 25)]
)
def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
+ bm.random.seed()
f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123)
y = f(bm.random.random(shape + (100,)) < 0.1)
self.assertTrue(y.shape == shape + (200,))
@@ -180,6 +192,7 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
shape=[(), (10,), (10, 20), (10, 20, 25)]
)
def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
+ bm.random.seed()
f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123)
y = f(bm.random.random(shape + (100,)) < 0.1)
self.assertTrue(y.shape == shape + (200,))
diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py
index a93a64de0..3e4da301e 100644
--- a/brainpy/_src/dnn/tests/test_normalization.py
+++ b/brainpy/_src/dnn/tests/test_normalization.py
@@ -9,6 +9,7 @@ class Test_Normalization(parameterized.TestCase):
fit=[True, False],
)
def test_BatchNorm1d(self, fit):
+ bm.random.seed()
net = bp.dnn.BatchNorm1d(num_features=10, mode=bm.training_mode)
bp.share.save(fit=fit)
input = bm.random.randn(1, 3, 10)
@@ -18,6 +19,7 @@ def test_BatchNorm1d(self, fit):
fit=[True, False]
)
def test_BatchNorm2d(self, fit):
+ bm.random.seed()
net = bp.dnn.BatchNorm2d(10, mode=bm.training_mode)
bp.share.save(fit=fit)
input = bm.random.randn(1, 3, 4, 10)
@@ -27,6 +29,7 @@ def test_BatchNorm2d(self, fit):
fit=[True, False]
)
def test_BatchNorm3d(self, fit):
+ bm.random.seed()
net = bp.dnn.BatchNorm3d(10, mode=bm.training_mode)
bp.share.save(fit=fit)
input = bm.random.randn(1, 3, 4, 5, 10)
@@ -36,6 +39,7 @@ def test_BatchNorm3d(self, fit):
normalized_shape=(10, [5, 10])
)
def test_LayerNorm(self, normalized_shape):
+ bm.random.seed()
net = bp.dnn.LayerNorm(normalized_shape, mode=bm.training_mode)
input = bm.random.randn(20, 5, 10)
output = net(input)
@@ -44,11 +48,13 @@ def test_LayerNorm(self, normalized_shape):
num_groups=[1, 2, 3, 6]
)
def test_GroupNorm(self, num_groups):
+ bm.random.seed()
input = bm.random.randn(20, 10, 10, 6)
net = bp.dnn.GroupNorm(num_groups=num_groups, num_channels=6, mode=bm.training_mode)
output = net(input)
def test_InstanceNorm(self):
+ bm.random.seed()
input = bm.random.randn(20, 10, 10, 6)
net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode)
output = net(input)
diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py
index b05932cb3..1acdf15bc 100644
--- a/brainpy/_src/dnn/tests/test_pooling_layers.py
+++ b/brainpy/_src/dnn/tests/test_pooling_layers.py
@@ -17,6 +17,7 @@ def __init__(self, *args, **kwargs):
self.rng = bm.random.default_rng(12345)
def test_maxpool(self):
+ bm.random.seed()
x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
print(jnp.arange(9).reshape(3, 3))
print(x)
@@ -31,6 +32,7 @@ def test_maxpool(self):
np.testing.assert_allclose(y, expected_y)
def test_maxpool2(self):
+ bm.random.seed()
x = self.rng.rand(10, 20, 20, 4)
with bm.training_environment():
net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1)
@@ -38,6 +40,7 @@ def test_maxpool2(self):
print("out shape: ", y.shape)
def test_minpool(self):
+ bm.random.seed()
x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
shared = {'fit': False}
with bm.training_environment():
@@ -51,6 +54,7 @@ def test_minpool(self):
np.testing.assert_allclose(y, expected_y)
def test_avgpool(self):
+ bm.random.seed()
x = jnp.full((1, 3, 3, 1), 2.)
with bm.training_environment():
net = bp.dnn.AvgPool((2, 2), 1, channel_axis=-1)
@@ -59,6 +63,7 @@ def test_avgpool(self):
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
def test_MaxPool2d_v1(self):
+ bm.random.seed()
arr = self.rng.rand(16, 32, 32, 8)
out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr)
@@ -80,6 +85,7 @@ def test_MaxPool2d_v1(self):
self.assertTrue(out.shape == (16, 17, 32, 5))
def test_AvgPool2d_v1(self):
+ bm.random.seed()
arr = self.rng.rand(16, 32, 32, 8)
out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1)(arr)
@@ -106,6 +112,7 @@ def test_AvgPool2d_v1(self):
for target_size in [10, 9, 8, 7, 6]
)
def test_adaptive_pool1d(self, target_size):
+ bm.random.seed()
from brainpy._src.dnn.pooling import _adaptive_pool1d
arr = self.rng.rand(100)
@@ -120,6 +127,7 @@ def test_adaptive_pool1d(self, target_size):
self.assertTrue(out.shape == (target_size,))
def test_AdaptiveAvgPool2d_v1(self):
+ bm.random.seed()
input = self.rng.randn(64, 8, 9)
output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
@@ -138,6 +146,7 @@ def test_AdaptiveAvgPool2d_v1(self):
self.assertTrue(output.shape == (64, 2, 3))
def test_AdaptiveAvgPool2d_v2(self):
+ bm.random.seed()
input = self.rng.randn(128, 64, 32, 16)
output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
@@ -154,12 +163,14 @@ def test_AdaptiveAvgPool2d_v2(self):
print()
def test_AdaptiveAvgPool3d_v1(self):
+ bm.random.seed()
input = bm.random.randn(10, 128, 64, 32)
net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], channel_axis=0, mode=bm.nonbatching_mode)
output = net(input)
self.assertTrue(output.shape == (10, 6, 5, 3))
def test_AdaptiveAvgPool3d_v2(self):
+ bm.random.seed()
input = bm.random.randn(10, 20, 128, 64, 32)
net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], mode=bm.batching_mode)
output = net(input)
@@ -169,6 +180,7 @@ def test_AdaptiveAvgPool3d_v2(self):
axis=(-1, 0, 1)
)
def test_AdaptiveMaxPool1d_v1(self, axis):
+ bm.random.seed()
input = bm.random.randn(32, 16)
net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
output = net(input)
@@ -177,6 +189,7 @@ def test_AdaptiveMaxPool1d_v1(self, axis):
axis=(-1, 0, 1, 2)
)
def test_AdaptiveMaxPool1d_v2(self, axis):
+ bm.random.seed()
input = bm.random.randn(2, 32, 16)
net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
output = net(input)
@@ -185,6 +198,7 @@ def test_AdaptiveMaxPool1d_v2(self, axis):
axis=(-1, 0, 1, 2)
)
def test_AdaptiveMaxPool2d_v1(self, axis):
+ bm.random.seed()
input = bm.random.randn(32, 16, 12)
net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis)
output = net(input)
@@ -193,6 +207,7 @@ def test_AdaptiveMaxPool2d_v1(self, axis):
axis=(-1, 0, 1, 2, 3)
)
def test_AdaptiveMaxPool2d_v2(self, axis):
+ bm.random.seed()
input = bm.random.randn(2, 32, 16, 12)
net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis)
# output = net(input)
@@ -201,6 +216,7 @@ def test_AdaptiveMaxPool2d_v2(self, axis):
axis=(-1, 0, 1, 2, 3)
)
def test_AdaptiveMaxPool3d_v1(self, axis):
+ bm.random.seed()
input = bm.random.randn(2, 128, 64, 32)
net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
output = net(input)
@@ -210,6 +226,7 @@ def test_AdaptiveMaxPool3d_v1(self, axis):
axis=(-1, 0, 1, 2, 3, 4)
)
def test_AdaptiveMaxPool3d_v1(self, axis):
+ bm.random.seed()
input = bm.random.randn(2, 128, 64, 32, 16)
net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
output = net(input)
From c4dff140057209f8173e0706c2aec0370c697dea Mon Sep 17 00:00:00 2001
From: GYF <1337838189@qq.com>
Date: Thu, 13 Jul 2023 15:54:02 +0800
Subject: [PATCH 035/326] update tests
---
.github/workflows/CI.yml | 4 ++--
brainpy/_src/dnn/tests/test_pooling_layers.py | 14 ++++++--------
2 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 801013f8b..845a4ac70 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -166,8 +166,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install numpy>=1.21.0
- python -m pip install "jaxlib==0.4.10" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
- python -m pip install jax==0.4.10
+ python -m pip install "jaxlib==0.4.11" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
+ python -m pip install jax==0.4.11
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py
index 1acdf15bc..64a7c881d 100644
--- a/brainpy/_src/dnn/tests/test_pooling_layers.py
+++ b/brainpy/_src/dnn/tests/test_pooling_layers.py
@@ -14,8 +14,6 @@ class TestPool(parameterized.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.rng = bm.random.default_rng(12345)
-
def test_maxpool(self):
bm.random.seed()
x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
@@ -33,7 +31,7 @@ def test_maxpool(self):
def test_maxpool2(self):
bm.random.seed()
- x = self.rng.rand(10, 20, 20, 4)
+ x = bm.random.rand(10, 20, 20, 4)
with bm.training_environment():
net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1)
y = net(x)
@@ -64,7 +62,7 @@ def test_avgpool(self):
def test_MaxPool2d_v1(self):
bm.random.seed()
- arr = self.rng.rand(16, 32, 32, 8)
+ arr = bm.random.rand(16, 32, 32, 8)
out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr)
self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -86,7 +84,7 @@ def test_MaxPool2d_v1(self):
def test_AvgPool2d_v1(self):
bm.random.seed()
- arr = self.rng.rand(16, 32, 32, 8)
+ arr = bm.random.rand(16, 32, 32, 8)
out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1)(arr)
self.assertTrue(out.shape == (16, 16, 16, 8))
@@ -115,7 +113,7 @@ def test_adaptive_pool1d(self, target_size):
bm.random.seed()
from brainpy._src.dnn.pooling import _adaptive_pool1d
- arr = self.rng.rand(100)
+ arr = bm.random.rand(100)
op = jax.numpy.mean
out = _adaptive_pool1d(arr, target_size, op)
@@ -128,7 +126,7 @@ def test_adaptive_pool1d(self, target_size):
def test_AdaptiveAvgPool2d_v1(self):
bm.random.seed()
- input = self.rng.randn(64, 8, 9)
+ input = bm.random.randn(64, 8, 9)
output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
self.assertTrue(output.shape == (64, 5, 7))
@@ -147,7 +145,7 @@ def test_AdaptiveAvgPool2d_v1(self):
def test_AdaptiveAvgPool2d_v2(self):
bm.random.seed()
- input = self.rng.randn(128, 64, 32, 16)
+ input = bm.random.randn(128, 64, 32, 16)
output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
self.assertTrue(output.shape == (128, 64, 5, 7))
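
Besides bumping the CI pins to jax/jaxlib 0.4.11, this patch removes the per-class self.rng = bm.random.default_rng(12345) fixture and samples directly from the seeded module-level generator. A before/after sketch, assuming bm.random.rand and bm.random.randn accept NumPy-style shape arguments as they are used throughout these tests:

  import brainpy.math as bm

  # before: a generator owned by the test class
  # self.rng = bm.random.default_rng(12345)
  # x = self.rng.rand(10, 20, 20, 4)

  # after: seed the global generator inside the test and draw from it directly
  bm.random.seed()
  x = bm.random.rand(10, 20, 20, 4)     # uniform samples with shape (10, 20, 20, 4)
  inp = bm.random.randn(64, 8, 9)       # standard-normal samples
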
From 979c89ce21c5fecddb89210a0e0c86e3aed48abb Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 15:57:23 +0800
Subject: [PATCH 036/326] delete brainpy.testing, since it causes unexpected
bugs
---
brainpy/__init__.py | 1 -
.../math/object_transform/tests/test_base.py | 16 ++++++++++++----
.../object_transform/tests/test_controls.py | 9 +++++++--
.../math/object_transform/tests/test_tools.py | 8 +++++++-
brainpy/_src/testing/__init__.py | 0
brainpy/_src/testing/base.py | 17 -----------------
brainpy/testing.py | 1 -
tests/simulation/test_net_rate_SL.py | 3 ++-
tests/simulation/test_neu_HH.py | 5 +++--
9 files changed, 31 insertions(+), 29 deletions(-)
delete mode 100644 brainpy/_src/testing/__init__.py
delete mode 100644 brainpy/_src/testing/base.py
delete mode 100644 brainpy/testing.py
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 90edaca3d..93db462d5 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -102,7 +102,6 @@
# Part: Others #
# ---------------- #
-from brainpy import testing
from brainpy._src.visualization import (visualize as visualize)
diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py
index 9435cf56f..2d640b3b5 100644
--- a/brainpy/_src/math/object_transform/tests/test_base.py
+++ b/brainpy/_src/math/object_transform/tests/test_base.py
@@ -79,8 +79,10 @@ def __init__(self):
self.assertTrue(len(net.vars(level=3, include_self=False)) == (2 + 4 + 8) * 2)
-class TestNodeList(bp.testing.UnitTestCase):
+class TestNodeList(unittest.TestCase):
def test_NodeList_1(self):
+ bm.random.seed()
+
class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
@@ -119,8 +121,10 @@ def update(self, x):
# print(jax.tree_util.tree_structure(obj))
-class TestNodeDict(bp.testing.UnitTestCase):
+class TestNodeDict(unittest.TestCase):
def test_NodeDict_1(self):
+ bm.random.seed()
+
class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
@@ -165,8 +169,10 @@ def update(self, x):
# print(jax.tree_util.tree_structure(obj))
-class TestVarList(bp.testing.UnitTestCase):
+class TestVarList(unittest.TestCase):
def test_ListVar_1(self):
+ bm.random.seed()
+
class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
@@ -194,8 +200,10 @@ def f2():
self.assertTrue(bm.allclose(obj.vs[2], bm.ones(10) * 11.))
-class TestVarDict(bp.testing.UnitTestCase):
+class TestVarDict(unittest.TestCase):
def test_DictVar_1(self):
+ bm.random.seed()
+
class Object(bp.DynamicalSystem):
def __init__(self):
super().__init__()
diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py
index 62e689535..359f03c74 100644
--- a/brainpy/_src/math/object_transform/tests/test_controls.py
+++ b/brainpy/_src/math/object_transform/tests/test_controls.py
@@ -188,8 +188,10 @@ def f2():
self.assertTrue(f2().size == 200)
-class TestWhile(bp.testing.UnitTestCase):
+class TestWhile(unittest.TestCase):
def test1(self):
+ bm.random.seed()
+
a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))
@@ -206,6 +208,8 @@ def body(x, y):
print(res)
def test3(self):
+ bm.random.seed()
+
a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))
@@ -224,8 +228,9 @@ def body(x, y):
print(a)
print(b)
-
def test2(self):
+ bm.random.seed()
+
a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))
diff --git a/brainpy/_src/math/object_transform/tests/test_tools.py b/brainpy/_src/math/object_transform/tests/test_tools.py
index 69781d624..22357c0b2 100644
--- a/brainpy/_src/math/object_transform/tests/test_tools.py
+++ b/brainpy/_src/math/object_transform/tests/test_tools.py
@@ -1,11 +1,13 @@
import brainpy as bp
import brainpy.math as bm
import jax
+import unittest
from brainpy._src.math.object_transform._tools import evaluate_dyn_vars_with_cache
-class TestTool(bp.testing.UnitTestCase):
+class TestTool(unittest.TestCase):
def test1(self):
+ bm.random.seed()
neu = bp.neurons.HH((5,))
call_num = [0]
@@ -22,6 +24,7 @@ def f():
self.assertTrue(isinstance(v.value, jax.Array))
def test_cache1(self):
+ bm.random.seed()
neu = bp.neurons.HH((5,))
call_num = [0]
@@ -44,6 +47,7 @@ def f():
self.assertTrue(isinstance(v.value, jax.Array))
def test_nested_evaluate(self):
+ bm.random.seed()
neu = bp.neurons.HH((5,))
a = bm.Variable(bm.ones(1))
@@ -64,6 +68,7 @@ def f2():
self.assertTrue(isinstance(a.value, jax.Array))
def test_cache2(self):
+ bm.random.seed()
neu = bp.neurons.HH((5,))
a = bm.Variable(bm.ones(1))
call_num = [0]
@@ -90,6 +95,7 @@ def f2():
self.assertTrue(call_num[0] == 1)
def test_cache3(self):
+ bm.random.seed()
call_num = [0]
class Model(bp.DynamicalSystem):
diff --git a/brainpy/_src/testing/__init__.py b/brainpy/_src/testing/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/brainpy/_src/testing/base.py b/brainpy/_src/testing/base.py
deleted file mode 100644
index 6f7f94c7a..000000000
--- a/brainpy/_src/testing/base.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import unittest
-import brainpy.math as bm
-import numpy as np
-
-try:
- from absl.testing import parameterized
-except ImportError:
- pass
-
-
-class UnitTestCase(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- bm.random.seed(np.random.randint(0, 100000))
- self.rng = bm.random.RandomState(np.random.randint(0, 100000))
-
-
diff --git a/brainpy/testing.py b/brainpy/testing.py
deleted file mode 100644
index f06131a3b..000000000
--- a/brainpy/testing.py
+++ /dev/null
@@ -1 +0,0 @@
-from brainpy._src.testing.base import UnitTestCase
diff --git a/tests/simulation/test_net_rate_SL.py b/tests/simulation/test_net_rate_SL.py
index fad1dd6ed..05d81c415 100644
--- a/tests/simulation/test_net_rate_SL.py
+++ b/tests/simulation/test_net_rate_SL.py
@@ -25,8 +25,9 @@ def __init__(self, noise=0.14):
)
-class TestSL(bp.testing.UnitTestCase):
+class TestSL(unittest.TestCase):
def test1(self):
+ bm.random.seed()
net = Network()
runner = bp.DSRunner(net, monitors=['sl.x'])
runner.run(6e3 if show else 1e2)
diff --git a/tests/simulation/test_neu_HH.py b/tests/simulation/test_neu_HH.py
index 2e80cabb5..ad0e51360 100644
--- a/tests/simulation/test_neu_HH.py
+++ b/tests/simulation/test_neu_HH.py
@@ -1,10 +1,11 @@
import brainpy as bp
import brainpy.math as bm
+import unittest
show = False
-class HH(bp.CondNeuGroup):
+class HH(bp.dyn.CondNeuGroup):
def __init__(self, size):
super(HH, self).__init__(size)
self.INa = bp.channels.INa_HH1952(size, )
@@ -89,7 +90,7 @@ def update(self, x=None):
return dV_grad
-class TestHH(bp.testing.UnitTestCase):
+class TestHH(unittest.TestCase):
def test1(self):
bm.random.seed()
hh = HH(1)
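
With brainpy.testing removed, suites that relied on UnitTestCase (which seeded the RNG in its __init__) now subclass unittest.TestCase and seed explicitly inside each test. A minimal sketch of the migration, with the class name and assertion chosen only for illustration:

  import unittest
  import brainpy.math as bm

  # previously: class TestMigrated(bp.testing.UnitTestCase)
  class TestMigrated(unittest.TestCase):
    def test_variables(self):
      bm.random.seed()                     # explicit seeding replaces UnitTestCase.__init__
      a = bm.Variable(bm.zeros(1))
      b = bm.Variable(bm.ones(1))
      self.assertTrue(bm.allclose(a + b, bm.ones(1)))
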
From bebe242909258534b13b0a50f966f778fcdbcd62 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 20:25:52 +0800
Subject: [PATCH 037/326] fix test bugs
---
brainpy/_src/dnn/tests/test_activation.py | 535 +++++++++---------
brainpy/_src/dnn/tests/test_conv_layers.py | 452 ++++++++-------
brainpy/_src/dnn/tests/test_function.py | 32 +-
brainpy/_src/dnn/tests/test_linear.py | 13 +
brainpy/_src/dnn/tests/test_normalization.py | 115 ++--
brainpy/_src/dnn/tests/test_pooling_layers.py | 453 ++++++++-------
6 files changed, 841 insertions(+), 759 deletions(-)
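
Beyond reindenting the test files, this patch appends bm.clear_buffer_memory() at the end of each test so device buffers allocated during a case are released before the next one runs. A sketch of the resulting test shape, mirroring the test_Tanh case in the diff below and assuming bm.clear_buffer_memory() needs no arguments (as it is called in this patch):

  import brainpy as bp
  import brainpy.math as bm
  from absl.testing import parameterized

  class Test_Activation(parameterized.TestCase):
    def test_Tanh(self):
      bm.random.seed()               # fresh RNG state for this test
      layer = bp.dnn.Tanh()
      x = bm.random.randn(2)
      y = layer(x)
      bm.clear_buffer_memory()       # free JAX device buffers before the next test
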
diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py
index ab7e20300..30bb8e032 100644
--- a/brainpy/_src/dnn/tests/test_activation.py
+++ b/brainpy/_src/dnn/tests/test_activation.py
@@ -2,263 +2,292 @@
from absl.testing import parameterized
from absl.testing import absltest
import brainpy as bp
+import brainpy.math as bm
class Test_Activation(parameterized.TestCase):
- @parameterized.product(
- inplace=[True, False]
- )
- def test_Threshold(self, inplace):
- bm.random.seed()
- threshold_layer = bp.dnn.Threshold(5, 20, inplace)
- input = bm.random.randn(2)
- if inplace == True:
- threshold_layer(input)
- elif inplace == False:
- output = threshold_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_ReLU(self, inplace):
- bm.random.seed()
- ReLU_layer = bp.dnn.ReLU(inplace)
- input = bm.random.randn(2)
- if inplace == True:
- ReLU_layer(input)
- elif inplace == False:
- output = ReLU_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_RReLU(self, inplace):
- bm.random.seed()
- RReLU_layer = bp.dnn.RReLU(lower=0, upper=1, inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- RReLU_layer(input)
- elif inplace == False:
- output = RReLU_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_Hardtanh(self, inplace):
- bm.random.seed()
- Hardtanh_layer = bp.dnn.Hardtanh(min_val=0, max_val=1, inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- Hardtanh_layer(input)
- elif inplace == False:
- output = Hardtanh_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_ReLU6(self, inplace):
- bm.random.seed()
- ReLU6_layer = bp.dnn.ReLU6(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- ReLU6_layer(input)
- elif inplace == False:
- output = ReLU6_layer(input)
-
- def test_Sigmoid(self):
- bm.random.seed()
- Sigmoid_layer = bp.dnn.Sigmoid()
- input = bm.random.randn(2)
- output = Sigmoid_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_Hardsigmoid(self, inplace):
- bm.random.seed()
- Hardsigmoid_layer = bp.dnn.Hardsigmoid(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- Hardsigmoid_layer(input)
- elif inplace == False:
- output = Hardsigmoid_layer(input)
-
- def test_Tanh(self):
- bm.random.seed()
- Tanh_layer = bp.dnn.Tanh()
- input = bm.random.randn(2)
- output = Tanh_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_SiLU(self, inplace):
- bm.random.seed()
- SiLU_layer = bp.dnn.SiLU(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- SiLU_layer(input)
- elif inplace == False:
- output = SiLU_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_Mish(self, inplace):
- bm.random.seed()
- Mish_layer = bp.dnn.Mish(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- Mish_layer(input)
- elif inplace == False:
- output = Mish_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_Hardswish(self, inplace):
- bm.random.seed()
- Hardswish_layer = bp.dnn.Hardswish(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- Hardswish_layer(input)
- elif inplace == False:
- output = Hardswish_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_ELU(self, inplace):
- bm.random.seed()
- ELU_layer = bp.dnn.ELU(alpha=0.5, inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- ELU_layer(input)
- elif inplace == False:
- output = ELU_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_CELU(self, inplace):
- bm.random.seed()
- CELU_layer = bp.dnn.CELU(alpha=0.5, inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- CELU_layer(input)
- elif inplace == False:
- output = CELU_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_SELU(self, inplace):
- bm.random.seed()
- SELU_layer = bp.dnn.SELU(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- SELU_layer(input)
- elif inplace == False:
- output = SELU_layer(input)
-
- def test_GLU(self):
- bm.random.seed()
- GLU_layer = bp.dnn.GLU()
- input = bm.random.randn(4, 2)
- output = GLU_layer(input)
-
- @parameterized.product(
- approximate=['tanh', 'none']
- )
- def test_GELU(self, approximate):
- bm.random.seed()
- GELU_layer = bp.dnn.GELU()
- input = bm.random.randn(2)
- output = GELU_layer(input)
-
- def test_Hardshrink(self):
- bm.random.seed()
- Hardshrink_layer = bp.dnn.Hardshrink(lambd=1)
- input = bm.random.randn(2)
- output = Hardshrink_layer(input)
-
- @parameterized.product(
- inplace=[True, False]
- )
- def test_LeakyReLU(self, inplace):
- bm.random.seed()
- LeakyReLU_layer = bp.dnn.LeakyReLU(inplace=inplace)
- input = bm.random.randn(2)
- if inplace == True:
- LeakyReLU_layer(input)
- elif inplace == False:
- output = LeakyReLU_layer(input)
-
- def test_LogSigmoid(self):
- bm.random.seed()
- LogSigmoid_layer = bp.dnn.LogSigmoid()
- input = bm.random.randn(2)
- output = LogSigmoid_layer(input)
-
- @parameterized.product(
- beta=[1, 2, 3],
- threshold=[20, 21, 22]
- )
- def test_Softplus(self, beta, threshold):
- bm.random.seed()
- Softplus_layer = bp.dnn.Softplus(beta=beta, threshold=threshold)
- input = bm.random.randn(2)
- output = Softplus_layer(input)
-
- def test_Softshrink(self):
- bm.random.seed()
- Softshrink_layer = bp.dnn.Softshrink(lambd=1)
- input = bm.random.randn(2)
- output = Softshrink_layer(input)
-
- def test_PReLU(self):
- bm.random.seed()
- PReLU_layer = bp.dnn.PReLU(num_parameters=2, init=0.5)
- input = bm.random.randn(2)
- output = PReLU_layer(input)
-
- def test_Softsign(self):
- bm.random.seed()
- Softsign_layer = bp.dnn.Softsign()
- input = bm.random.randn(2)
- output = Softsign_layer(input)
-
- def test_Tanhshrink(self):
- bm.random.seed()
- Tanhshrink_layer = bp.dnn.Tanhshrink()
- input = bm.random.randn(2)
- output = Tanhshrink_layer(input)
-
- def test_Softmin(self):
- bm.random.seed()
- Softmin_layer = bp.dnn.Softmin(dim=2)
- input = bm.random.randn(2, 3, 4)
- output = Softmin_layer(input)
-
- def test_Softmax(self):
- bm.random.seed()
- Softmax_layer = bp.dnn.Softmax(dim=2)
- input = bm.random.randn(2, 3, 4)
- output = Softmax_layer(input)
-
- def test_Softmax2d(self):
- bm.random.seed()
- Softmax2d_layer = bp.dnn.Softmax2d()
- input = bm.random.randn(2, 3, 12, 13)
- output = Softmax2d_layer(input)
-
- def test_LogSoftmax(self):
- bm.random.seed()
- LogSoftmax_layer = bp.dnn.LogSoftmax(dim=2)
- input = bm.random.randn(2, 3, 4)
- output = LogSoftmax_layer(input)
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Threshold(self, inplace):
+ bm.random.seed()
+ threshold_layer = bp.dnn.Threshold(5, 20, inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ threshold_layer(input)
+ elif inplace == False:
+ output = threshold_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_ReLU(self, inplace):
+ bm.random.seed()
+ ReLU_layer = bp.dnn.ReLU(inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ ReLU_layer(input)
+ elif inplace == False:
+ output = ReLU_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_RReLU(self, inplace):
+ bm.random.seed()
+ RReLU_layer = bp.dnn.RReLU(lower=0, upper=1, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ RReLU_layer(input)
+ elif inplace == False:
+ output = RReLU_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Hardtanh(self, inplace):
+ bm.random.seed()
+ Hardtanh_layer = bp.dnn.Hardtanh(min_val=0, max_val=1, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Hardtanh_layer(input)
+ elif inplace == False:
+ output = Hardtanh_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_ReLU6(self, inplace):
+ bm.random.seed()
+ ReLU6_layer = bp.dnn.ReLU6(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ ReLU6_layer(input)
+ elif inplace == False:
+ output = ReLU6_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Sigmoid(self):
+ bm.random.seed()
+ Sigmoid_layer = bp.dnn.Sigmoid()
+ input = bm.random.randn(2)
+ output = Sigmoid_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Hardsigmoid(self, inplace):
+ bm.random.seed()
+ Hardsigmoid_layer = bp.dnn.Hardsigmoid(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Hardsigmoid_layer(input)
+ elif inplace == False:
+ output = Hardsigmoid_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Tanh(self):
+ bm.random.seed()
+ Tanh_layer = bp.dnn.Tanh()
+ input = bm.random.randn(2)
+ output = Tanh_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_SiLU(self, inplace):
+ bm.random.seed()
+ SiLU_layer = bp.dnn.SiLU(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ SiLU_layer(input)
+ elif inplace == False:
+ output = SiLU_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Mish(self, inplace):
+ bm.random.seed()
+ Mish_layer = bp.dnn.Mish(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Mish_layer(input)
+ elif inplace == False:
+ output = Mish_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_Hardswish(self, inplace):
+ bm.random.seed()
+ Hardswish_layer = bp.dnn.Hardswish(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ Hardswish_layer(input)
+ elif inplace == False:
+ output = Hardswish_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_ELU(self, inplace):
+ bm.random.seed()
+ ELU_layer = bp.dnn.ELU(alpha=0.5, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ ELU_layer(input)
+ elif inplace == False:
+ output = ELU_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_CELU(self, inplace):
+ bm.random.seed()
+ CELU_layer = bp.dnn.CELU(alpha=0.5, inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ CELU_layer(input)
+ elif inplace == False:
+ output = CELU_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_SELU(self, inplace):
+ bm.random.seed()
+ SELU_layer = bp.dnn.SELU(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ SELU_layer(input)
+ elif inplace == False:
+ output = SELU_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_GLU(self):
+ bm.random.seed()
+ GLU_layer = bp.dnn.GLU()
+ input = bm.random.randn(4, 2)
+ output = GLU_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ approximate=['tanh', 'none']
+ )
+ def test_GELU(self, approximate):
+ bm.random.seed()
+ GELU_layer = bp.dnn.GELU()
+ input = bm.random.randn(2)
+ output = GELU_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Hardshrink(self):
+ bm.random.seed()
+ Hardshrink_layer = bp.dnn.Hardshrink(lambd=1)
+ input = bm.random.randn(2)
+ output = Hardshrink_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ inplace=[True, False]
+ )
+ def test_LeakyReLU(self, inplace):
+ bm.random.seed()
+ LeakyReLU_layer = bp.dnn.LeakyReLU(inplace=inplace)
+ input = bm.random.randn(2)
+ if inplace == True:
+ LeakyReLU_layer(input)
+ elif inplace == False:
+ output = LeakyReLU_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_LogSigmoid(self):
+ bm.random.seed()
+ LogSigmoid_layer = bp.dnn.LogSigmoid()
+ input = bm.random.randn(2)
+ output = LogSigmoid_layer(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ beta=[1, 2, 3],
+ threshold=[20, 21, 22]
+ )
+ def test_Softplus(self, beta, threshold):
+ bm.random.seed()
+ Softplus_layer = bp.dnn.Softplus(beta=beta, threshold=threshold)
+ input = bm.random.randn(2)
+ output = Softplus_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Softshrink(self):
+ bm.random.seed()
+ Softshrink_layer = bp.dnn.Softshrink(lambd=1)
+ input = bm.random.randn(2)
+ output = Softshrink_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_PReLU(self):
+ bm.random.seed()
+ PReLU_layer = bp.dnn.PReLU(num_parameters=2, init=0.5)
+ input = bm.random.randn(2)
+ output = PReLU_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Softsign(self):
+ bm.random.seed()
+ Softsign_layer = bp.dnn.Softsign()
+ input = bm.random.randn(2)
+ output = Softsign_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Tanhshrink(self):
+ bm.random.seed()
+ Tanhshrink_layer = bp.dnn.Tanhshrink()
+ input = bm.random.randn(2)
+ output = Tanhshrink_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Softmin(self):
+ bm.random.seed()
+ Softmin_layer = bp.dnn.Softmin(dim=2)
+ input = bm.random.randn(2, 3, 4)
+ output = Softmin_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Softmax(self):
+ bm.random.seed()
+ Softmax_layer = bp.dnn.Softmax(dim=2)
+ input = bm.random.randn(2, 3, 4)
+ output = Softmax_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_Softmax2d(self):
+ bm.random.seed()
+ Softmax2d_layer = bp.dnn.Softmax2d()
+ input = bm.random.randn(2, 3, 12, 13)
+ output = Softmax2d_layer(input)
+ bm.clear_buffer_memory()
+
+ def test_LogSoftmax(self):
+ bm.random.seed()
+ LogSoftmax_layer = bp.dnn.LogSoftmax(dim=2)
+ input = bm.random.randn(2, 3, 4)
+ output = LogSoftmax_layer(input)
+ bm.clear_buffer_memory()
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py
index b8ebc0adb..3c9fdfa87 100644
--- a/brainpy/_src/dnn/tests/test_conv_layers.py
+++ b/brainpy/_src/dnn/tests/test_conv_layers.py
@@ -6,251 +6,265 @@
import brainpy.math as bm
from absl.testing import parameterized
import brainpy as bp
+import brainpy.math as bm
class TestConv(parameterized.TestCase):
- def test_Conv2D_img(self):
- bm.random.seed()
- img = jnp.zeros((2, 200, 198, 4))
- for k in range(4):
- x = 30 + 60 * k
- y = 20 + 60 * k
- img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
- img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
+ def test_Conv2D_img(self):
+ bm.random.seed()
+ img = jnp.zeros((2, 200, 198, 4))
+ for k in range(4):
+ x = 30 + 60 * k
+ y = 20 + 60 * k
+ img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
+ img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
- with bp.math.training_environment():
- net = bp.layers.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3),
- strides=(2, 1), padding='VALID', groups=4)
- out = net(img)
- print("out shape: ", out.shape)
- # print("First output channel:")
- # plt.figure(figsize=(10, 10))
- # plt.imshow(np.array(img)[0, :, :, 0])
- # plt.show()
+ with bp.math.training_environment():
+ net = bp.layers.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3),
+ strides=(2, 1), padding='VALID', groups=4)
+ out = net(img)
+ print("out shape: ", out.shape)
+ # print("First output channel:")
+ # plt.figure(figsize=(10, 10))
+ # plt.imshow(np.array(img)[0, :, :, 0])
+ # plt.show()
+ bm.clear_buffer_memory()
- def test_conv1D(self):
- bm.random.seed()
- with bp.math.training_environment():
- model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
+ def test_conv1D(self):
+ bm.random.seed()
+ with bp.math.training_environment():
+ model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,))
- input = bp.math.ones((2, 5, 3))
+ input = bp.math.ones((2, 5, 3))
- out = model(input)
- print("out shape: ", out.shape)
- # print("First output channel:")
- # plt.figure(figsize=(10, 10))
- # plt.imshow(np.array(out)[0, :, :])
- # plt.show()
+ out = model(input)
+ print("out shape: ", out.shape)
+ # print("First output channel:")
+ # plt.figure(figsize=(10, 10))
+ # plt.imshow(np.array(out)[0, :, :])
+ # plt.show()
+ bm.clear_buffer_memory()
- def test_conv2D(self):
- bm.random.seed()
- with bp.math.training_environment():
- model = bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
+ def test_conv2D(self):
+ bm.random.seed()
+ with bp.math.training_environment():
+ model = bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))
- input = bp.math.ones((2, 5, 5, 3))
+ input = bp.math.ones((2, 5, 5, 3))
- out = model(input)
- print("out shape: ", out.shape)
- # print("First output channel:")
- # plt.figure(figsize=(10, 10))
- # plt.imshow(np.array(out)[0, :, :, 31])
- # plt.show()
+ out = model(input)
+ print("out shape: ", out.shape)
+ # print("First output channel:")
+ # plt.figure(figsize=(10, 10))
+ # plt.imshow(np.array(out)[0, :, :, 31])
+ # plt.show()
+ bm.clear_buffer_memory()
- def test_conv3D(self):
- bm.random.seed()
- with bp.math.training_environment():
- model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3))
- input = bp.math.ones((2, 5, 5, 5, 3))
- out = model(input)
- print("out shape: ", out.shape)
+ def test_conv3D(self):
+ bm.random.seed()
+ with bp.math.training_environment():
+ model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3))
+ input = bp.math.ones((2, 5, 5, 5, 3))
+ out = model(input)
+ print("out shape: ", out.shape)
+ bm.clear_buffer_memory()
class TestConvTranspose1d(parameterized.TestCase):
- def test_conv_transpose(self):
- bm.random.seed()
- x = bm.ones((1, 8, 3))
- for use_bias in [True, False]:
- conv_transpose_module = bp.layers.ConvTranspose1d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3,),
- padding='VALID',
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.training_mode
- )
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
- y = conv_transpose_module(x)
- print(y.shape)
- correct_ans = jnp.array([[[4., 4., 4., 4.],
- [7., 7., 7., 7.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [10., 10., 10., 10.],
- [7., 7., 7., 7.],
- [4., 4., 4., 4.]]])
- if not use_bias:
- correct_ans -= 1.
- self.assertTrue(bm.allclose(y, correct_ans))
+ def test_conv_transpose(self):
+ bm.random.seed()
+ x = bm.ones((1, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose1d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3,),
+ padding='VALID',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+ y = conv_transpose_module(x)
+ print(y.shape)
+ correct_ans = jnp.array([[[4., 4., 4., 4.],
+ [7., 7., 7., 7.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [10., 10., 10., 10.],
+ [7., 7., 7., 7.],
+ [4., 4., 4., 4.]]])
+ if not use_bias:
+ correct_ans -= 1.
+ self.assertTrue(bm.allclose(y, correct_ans))
+ bm.clear_buffer_memory()
- def test_single_input_masked_conv_transpose(self):
- bm.random.seed()
- x = jnp.ones((1, 8, 3))
- m = jnp.tril(jnp.ones((3, 3, 4)))
- conv_transpose_module = bp.layers.ConvTranspose1d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3,),
- padding='VALID',
- mask=m,
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit(),
- mode=bm.batching_mode
- )
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
- y = conv_transpose_module(x)
- print(y.shape)
- correct_ans = jnp.array([[[4., 3., 2., 1.],
- [7., 5., 3., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [10., 7., 4., 1.],
- [7., 5., 3., 1.],
- [4., 3., 2., 1.]]])
- self.assertTrue(bm.allclose(y, correct_ans))
+ def test_single_input_masked_conv_transpose(self):
+ bm.random.seed()
+ x = jnp.ones((1, 8, 3))
+ m = jnp.tril(jnp.ones((3, 3, 4)))
+ conv_transpose_module = bp.layers.ConvTranspose1d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3,),
+ padding='VALID',
+ mask=m,
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit(),
+ mode=bm.batching_mode
+ )
+ self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+ y = conv_transpose_module(x)
+ print(y.shape)
+ correct_ans = jnp.array([[[4., 3., 2., 1.],
+ [7., 5., 3., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [10., 7., 4., 1.],
+ [7., 5., 3., 1.],
+ [4., 3., 2., 1.]]])
+ self.assertTrue(bm.allclose(y, correct_ans))
+ bm.clear_buffer_memory()
- def test_computation_padding_same(self):
- bm.random.seed()
- data = jnp.ones([1, 3, 1])
- for use_bias in [True, False]:
- net = bp.layers.ConvTranspose1d(
- in_channels=1,
- out_channels=1,
- kernel_size=3,
- stride=1,
- padding="SAME",
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.batching_mode
- )
- out = net(data)
- self.assertEqual(out.shape, (1, 3, 1))
- out = jnp.squeeze(out, axis=(0, 2))
- expected_out = bm.as_jax([2, 3, 2])
- if use_bias:
- expected_out += 1
- self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5))
+ def test_computation_padding_same(self):
+ bm.random.seed()
+ data = jnp.ones([1, 3, 1])
+ for use_bias in [True, False]:
+ net = bp.layers.ConvTranspose1d(
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ stride=1,
+ padding="SAME",
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.batching_mode
+ )
+ out = net(data)
+ self.assertEqual(out.shape, (1, 3, 1))
+ out = jnp.squeeze(out, axis=(0, 2))
+ expected_out = bm.as_jax([2, 3, 2])
+ if use_bias:
+ expected_out += 1
+ self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5))
+ bm.clear_buffer_memory()
class TestConvTranspose2d(parameterized.TestCase):
- def test_conv_transpose(self):
- bm.random.seed()
- x = bm.ones((1, 8, 8, 3))
- for use_bias in [True, False]:
- conv_transpose_module = bp.layers.ConvTranspose2d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3, 3),
- padding='VALID',
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.training_mode
- )
- self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
- y = conv_transpose_module(x)
- print(y.shape)
+ def test_conv_transpose(self):
+ bm.random.seed()
+ x = bm.ones((1, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose2d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3),
+ padding='VALID',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
+ y = conv_transpose_module(x)
+ print(y.shape)
+ bm.clear_buffer_memory()
- def test_single_input_masked_conv_transpose(self):
- bm.random.seed()
- x = jnp.ones((1, 8, 8, 3))
- m = jnp.tril(jnp.ones((3, 3, 3, 4)))
- conv_transpose_module = bp.layers.ConvTranspose2d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3, 3),
- padding='VALID',
- mask=m,
- w_initializer=bp.init.OneInit(),
- mode=bm.training_mode
- )
- y = conv_transpose_module(x)
- print(y.shape)
+ def test_single_input_masked_conv_transpose(self):
+ bm.random.seed()
+ x = jnp.ones((1, 8, 8, 3))
+ m = jnp.tril(jnp.ones((3, 3, 3, 4)))
+ conv_transpose_module = bp.layers.ConvTranspose2d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3),
+ padding='VALID',
+ mask=m,
+ w_initializer=bp.init.OneInit(),
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+ bm.clear_buffer_memory()
- def test_computation_padding_same(self):
- bm.random.seed()
- x = bm.ones((1, 8, 8, 3))
- for use_bias in [True, False]:
- conv_transpose_module = bp.layers.ConvTranspose2d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3, 3),
- stride=1,
- padding='SAME',
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.training_mode,
- # mode=bm.nonbatching_mode,
- )
- y = conv_transpose_module(x)
- print(y.shape)
+ def test_computation_padding_same(self):
+ bm.random.seed()
+ x = bm.ones((1, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose2d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3),
+ stride=1,
+ padding='SAME',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode,
+ # mode=bm.nonbatching_mode,
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+ bm.clear_buffer_memory()
class TestConvTranspose3d(parameterized.TestCase):
- def test_conv_transpose(self):
- bm.random.seed()
- x = bm.ones((1, 8, 8, 8, 3))
- for use_bias in [True, False]:
- conv_transpose_module = bp.layers.ConvTranspose3d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3, 3, 3),
- padding='VALID',
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.training_mode
- )
- y = conv_transpose_module(x)
- print(y.shape)
+ def test_conv_transpose(self):
+ bm.random.seed()
+ x = bm.ones((1, 8, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose3d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3, 3),
+ padding='VALID',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+ bm.clear_buffer_memory()
- def test_single_input_masked_conv_transpose(self):
- bm.random.seed()
- x = jnp.ones((1, 8, 8, 8, 3))
- m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
- conv_transpose_module = bp.layers.ConvTranspose3d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3, 3, 3),
- padding='VALID',
- mask=m,
- w_initializer=bp.init.OneInit(),
- mode=bm.training_mode
- )
- y = conv_transpose_module(x)
- print(y.shape)
+ def test_single_input_masked_conv_transpose(self):
+ bm.random.seed()
+ x = jnp.ones((1, 8, 8, 8, 3))
+ m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
+ conv_transpose_module = bp.layers.ConvTranspose3d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3, 3),
+ padding='VALID',
+ mask=m,
+ w_initializer=bp.init.OneInit(),
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+ bm.clear_buffer_memory()
- def test_computation_padding_same(self):
- bm.random.seed()
- x = bm.ones((1, 8, 8, 8, 3))
- for use_bias in [True, False]:
- conv_transpose_module = bp.layers.ConvTranspose3d(
- in_channels=3,
- out_channels=4,
- kernel_size=(3, 3, 3),
- stride=1,
- padding='SAME',
- w_initializer=bp.init.OneInit(),
- b_initializer=bp.init.OneInit() if use_bias else None,
- mode=bm.training_mode
- )
- y = conv_transpose_module(x)
- print(y.shape)
+ def test_computation_padding_same(self):
+ bm.random.seed()
+ x = bm.ones((1, 8, 8, 8, 3))
+ for use_bias in [True, False]:
+ conv_transpose_module = bp.layers.ConvTranspose3d(
+ in_channels=3,
+ out_channels=4,
+ kernel_size=(3, 3, 3),
+ stride=1,
+ padding='SAME',
+ w_initializer=bp.init.OneInit(),
+ b_initializer=bp.init.OneInit() if use_bias else None,
+ mode=bm.training_mode
+ )
+ y = conv_transpose_module(x)
+ print(y.shape)
+ bm.clear_buffer_memory()
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py
index 90dcae17b..a686d2a41 100644
--- a/brainpy/_src/dnn/tests/test_function.py
+++ b/brainpy/_src/dnn/tests/test_function.py
@@ -11,26 +11,28 @@
class TestFunction(parameterized.TestCase):
- def test_flatten_batching_mode(self):
- bm.random.seed()
- layer = bp.dnn.Flatten(mode=bm.BatchingMode())
- input = bm.random.randn(20, 10, 10, 6)
+ def test_flatten_batching_mode(self):
+ bm.random.seed()
+ layer = bp.dnn.Flatten(mode=bm.BatchingMode())
+ input = bm.random.randn(20, 10, 10, 6)
- output = layer.update(input)
+ output = layer.update(input)
- expected_shape = (20, 600)
- self.assertEqual(output.shape, expected_shape)
+ expected_shape = (20, 600)
+ self.assertEqual(output.shape, expected_shape)
+ bm.clear_buffer_memory()
- def test_flatten_non_batching_mode(self):
- bm.random.seed()
- layer = bp.dnn.Flatten(mode=bm.NonBatchingMode())
- input = bm.random.randn(10, 10, 6)
+ def test_flatten_non_batching_mode(self):
+ bm.random.seed()
+ layer = bp.dnn.Flatten(mode=bm.NonBatchingMode())
+ input = bm.random.randn(10, 10, 6)
- output = layer.update(input)
+ output = layer.update(input)
- expected_shape = (600,)
- self.assertEqual(output.shape, expected_shape)
+ expected_shape = (600,)
+ self.assertEqual(output.shape, expected_shape)
+ bm.clear_buffer_memory()
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py
index 98214563a..da49bdbfe 100644
--- a/brainpy/_src/dnn/tests/test_linear.py
+++ b/brainpy/_src/dnn/tests/test_linear.py
@@ -21,6 +21,7 @@ def test_Dense1(self, size, num_out):
x = bm.random.random(size)
y = f(x)
self.assertTrue(y.shape == size[:-1] + (num_out,))
+ bm.clear_buffer_memory()
@parameterized.product(
size=[(10,),
@@ -33,6 +34,7 @@ def test_Identity(self, size):
x = bm.random.random(size)
y = f(x)
self.assertTrue(y.shape == size)
+ bm.clear_buffer_memory()
def test_AllToAll1(self):
bm.random.seed()
@@ -49,6 +51,7 @@ def test_AllToAll1(self):
y = f(x)
expected = bm.sum(x, keepdims=True) * 0.1
self.assertTrue(bm.allclose(y, expected))
+ bm.clear_buffer_memory()
def test_OneToOne(self):
bm.random.seed()
@@ -65,6 +68,7 @@ def test_OneToOne(self):
y = f(x)
expected = x * 0.1
self.assertTrue(bm.allclose(y, expected))
+ bm.clear_buffer_memory()
@parameterized.product(
conn=[
@@ -80,6 +84,7 @@ def test_MaskedLinear(self, conn):
x = bm.random.random((16, 100))
y = f(x)
self.assertTrue(y.shape == (16, 100))
+ bm.clear_buffer_memory()
@parameterized.product(
conn=[
@@ -98,6 +103,7 @@ def test_CSRLinear(self, conn):
x = bm.random.random((100,))
y = f(x)
self.assertTrue(y.shape == (100,))
+ bm.clear_buffer_memory()
@parameterized.product(
@@ -116,6 +122,7 @@ def test_EventCSRLinear(self,conn):
x = bm.random.random((100,))
y = f(x)
self.assertTrue(y.shape == (100,))
+ bm.clear_buffer_memory()
@parameterized.product(
@@ -129,6 +136,7 @@ def test_JitFPHomoLinear(self, prob, weight, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))
+ bm.clear_buffer_memory()
@parameterized.product(
prob=[0.01, 0.05, 0.5],
@@ -142,6 +150,7 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))
+ bm.clear_buffer_memory()
@parameterized.product(
prob=[0.01, 0.1, 0.5],
@@ -155,6 +164,7 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
x = bm.random.random(shape + (100,))
y = f(x)
self.assertTrue(y.shape == shape + (200,))
+ bm.clear_buffer_memory()
@parameterized.product(
prob=[0.01, 0.05, 0.5],
@@ -169,6 +179,7 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape):
y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))
+ bm.clear_buffer_memory()
@parameterized.product(
prob=[0.01, 0.05, 0.5],
@@ -184,6 +195,7 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape):
y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))
+ bm.clear_buffer_memory()
@parameterized.product(
prob=[0.01, 0.1, 0.5],
@@ -199,6 +211,7 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape):
y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float))
self.assertTrue(y2.shape == shape + (200,))
+ bm.clear_buffer_memory()
if __name__ == '__main__':
diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py
index 3e4da301e..fdc5b34e3 100644
--- a/brainpy/_src/dnn/tests/test_normalization.py
+++ b/brainpy/_src/dnn/tests/test_normalization.py
@@ -5,59 +5,66 @@
class Test_Normalization(parameterized.TestCase):
- @parameterized.product(
- fit=[True, False],
- )
- def test_BatchNorm1d(self, fit):
- bm.random.seed()
- net = bp.dnn.BatchNorm1d(num_features=10, mode=bm.training_mode)
- bp.share.save(fit=fit)
- input = bm.random.randn(1, 3, 10)
- output = net(input)
-
- @parameterized.product(
- fit=[True, False]
- )
- def test_BatchNorm2d(self, fit):
- bm.random.seed()
- net = bp.dnn.BatchNorm2d(10, mode=bm.training_mode)
- bp.share.save(fit=fit)
- input = bm.random.randn(1, 3, 4, 10)
- output = net(input)
-
- @parameterized.product(
- fit=[True, False]
- )
- def test_BatchNorm3d(self, fit):
- bm.random.seed()
- net = bp.dnn.BatchNorm3d(10, mode=bm.training_mode)
- bp.share.save(fit=fit)
- input = bm.random.randn(1, 3, 4, 5, 10)
- output = net(input)
-
- @parameterized.product(
- normalized_shape=(10, [5, 10])
- )
- def test_LayerNorm(self, normalized_shape):
- bm.random.seed()
- net = bp.dnn.LayerNorm(normalized_shape, mode=bm.training_mode)
- input = bm.random.randn(20, 5, 10)
- output = net(input)
-
- @parameterized.product(
- num_groups=[1, 2, 3, 6]
- )
- def test_GroupNorm(self, num_groups):
- bm.random.seed()
- input = bm.random.randn(20, 10, 10, 6)
- net = bp.dnn.GroupNorm(num_groups=num_groups, num_channels=6, mode=bm.training_mode)
- output = net(input)
-
- def test_InstanceNorm(self):
- bm.random.seed()
- input = bm.random.randn(20, 10, 10, 6)
- net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode)
- output = net(input)
+ @parameterized.product(
+ fit=[True, False],
+ )
+ def test_BatchNorm1d(self, fit):
+ bm.random.seed()
+ net = bp.dnn.BatchNorm1d(num_features=10, mode=bm.training_mode)
+ bp.share.save(fit=fit)
+ input = bm.random.randn(1, 3, 10)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ fit=[True, False]
+ )
+ def test_BatchNorm2d(self, fit):
+ bm.random.seed()
+ net = bp.dnn.BatchNorm2d(10, mode=bm.training_mode)
+ bp.share.save(fit=fit)
+ input = bm.random.randn(1, 3, 4, 10)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ fit=[True, False]
+ )
+ def test_BatchNorm3d(self, fit):
+ bm.random.seed()
+ net = bp.dnn.BatchNorm3d(10, mode=bm.training_mode)
+ bp.share.save(fit=fit)
+ input = bm.random.randn(1, 3, 4, 5, 10)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ normalized_shape=(10, [5, 10])
+ )
+ def test_LayerNorm(self, normalized_shape):
+ bm.random.seed()
+ net = bp.dnn.LayerNorm(normalized_shape, mode=bm.training_mode)
+ input = bm.random.randn(20, 5, 10)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ num_groups=[1, 2, 3, 6]
+ )
+ def test_GroupNorm(self, num_groups):
+ bm.random.seed()
+ input = bm.random.randn(20, 10, 10, 6)
+ net = bp.dnn.GroupNorm(num_groups=num_groups, num_channels=6, mode=bm.training_mode)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ def test_InstanceNorm(self):
+ bm.random.seed()
+ input = bm.random.randn(20, 10, 10, 6)
+ net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode)
+ output = net(input)
+ bm.clear_buffer_memory()
+
if __name__ == '__main__':
- absltest.main()
\ No newline at end of file
+ absltest.main()
diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py
index 64a7c881d..34f8f5cd5 100644
--- a/brainpy/_src/dnn/tests/test_pooling_layers.py
+++ b/brainpy/_src/dnn/tests/test_pooling_layers.py
@@ -11,224 +11,241 @@
class TestPool(parameterized.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def test_maxpool(self):
- bm.random.seed()
- x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
- print(jnp.arange(9).reshape(3, 3))
- print(x)
- print(x.shape)
- shared = {'fit': False}
- with bm.training_environment():
- net = bp.dnn.MaxPool((2, 2), 1, channel_axis=-1)
- y = net(shared, x)
- print("out shape: ", y.shape)
- expected_y = jnp.array([[4., 5.],
- [7., 8.]]).reshape((1, 2, 2, 1))
- np.testing.assert_allclose(y, expected_y)
-
- def test_maxpool2(self):
- bm.random.seed()
- x = bm.random.rand(10, 20, 20, 4)
- with bm.training_environment():
- net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1)
- y = net(x)
- print("out shape: ", y.shape)
-
- def test_minpool(self):
- bm.random.seed()
- x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
- shared = {'fit': False}
- with bm.training_environment():
- net = bp.dnn.MinPool((2, 2), 1, channel_axis=-1)
- y = net(shared, x)
- print("out shape: ", y.shape)
- expected_y = jnp.array([
- [0., 1.],
- [3., 4.],
- ]).reshape((1, 2, 2, 1))
- np.testing.assert_allclose(y, expected_y)
-
- def test_avgpool(self):
- bm.random.seed()
- x = jnp.full((1, 3, 3, 1), 2.)
- with bm.training_environment():
- net = bp.dnn.AvgPool((2, 2), 1, channel_axis=-1)
- y = net(x)
- print("out shape: ", y.shape)
- np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
-
- def test_MaxPool2d_v1(self):
- bm.random.seed()
- arr = bm.random.rand(16, 32, 32, 8)
-
- out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr)
- self.assertTrue(out.shape == (16, 16, 16, 8))
-
- out = bp.dnn.MaxPool2d(2, 2, channel_axis=None)(arr)
- self.assertTrue(out.shape == (16, 32, 16, 4))
-
- out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
- self.assertTrue(out.shape == (16, 32, 17, 5))
-
- out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
- self.assertTrue(out.shape == (16, 32, 18, 5))
-
- out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 17, 8))
-
- out = bp.dnn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 32, 5))
-
- def test_AvgPool2d_v1(self):
- bm.random.seed()
- arr = bm.random.rand(16, 32, 32, 8)
-
- out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1)(arr)
- self.assertTrue(out.shape == (16, 16, 16, 8))
-
- out = bp.dnn.AvgPool2d(2, 2, channel_axis=None)(arr)
- self.assertTrue(out.shape == (16, 32, 16, 4))
-
- out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
- self.assertTrue(out.shape == (16, 32, 17, 5))
-
- out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
- self.assertTrue(out.shape == (16, 32, 18, 5))
-
- out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 17, 8))
-
- out = bp.dnn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
- self.assertTrue(out.shape == (16, 17, 32, 5))
-
- @parameterized.named_parameters(
- dict(testcase_name=f'target_size={target_size}',
- target_size=target_size)
- for target_size in [10, 9, 8, 7, 6]
- )
- def test_adaptive_pool1d(self, target_size):
- bm.random.seed()
- from brainpy._src.dnn.pooling import _adaptive_pool1d
-
- arr = bm.random.rand(100)
- op = jax.numpy.mean
-
- out = _adaptive_pool1d(arr, target_size, op)
- print(out.shape)
- self.assertTrue(out.shape == (target_size,))
-
- out = _adaptive_pool1d(arr, target_size, op)
- print(out.shape)
- self.assertTrue(out.shape == (target_size,))
-
- def test_AdaptiveAvgPool2d_v1(self):
- bm.random.seed()
- input = bm.random.randn(64, 8, 9)
-
- output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
- self.assertTrue(output.shape == (64, 5, 7))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
- self.assertTrue(output.shape == (64, 2, 3))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
- self.assertTrue(output.shape == (2, 3, 9))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
- self.assertTrue(output.shape == (2, 8, 3))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input)
- self.assertTrue(output.shape == (64, 2, 3))
-
- def test_AdaptiveAvgPool2d_v2(self):
- bm.random.seed()
- input = bm.random.randn(128, 64, 32, 16)
-
- output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
- self.assertTrue(output.shape == (128, 64, 5, 7))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
- self.assertTrue(output.shape == (128, 64, 2, 3))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
- self.assertTrue(output.shape == (128, 2, 3, 16))
-
- output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
- self.assertTrue(output.shape == (128, 64, 2, 3))
- print()
-
- def test_AdaptiveAvgPool3d_v1(self):
- bm.random.seed()
- input = bm.random.randn(10, 128, 64, 32)
- net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], channel_axis=0, mode=bm.nonbatching_mode)
- output = net(input)
- self.assertTrue(output.shape == (10, 6, 5, 3))
-
- def test_AdaptiveAvgPool3d_v2(self):
- bm.random.seed()
- input = bm.random.randn(10, 20, 128, 64, 32)
- net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], mode=bm.batching_mode)
- output = net(input)
- self.assertTrue(output.shape == (10, 6, 5, 3, 32))
-
- @parameterized.product(
- axis=(-1, 0, 1)
- )
- def test_AdaptiveMaxPool1d_v1(self, axis):
- bm.random.seed()
- input = bm.random.randn(32, 16)
- net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
- output = net(input)
-
- @parameterized.product(
- axis=(-1, 0, 1, 2)
- )
- def test_AdaptiveMaxPool1d_v2(self, axis):
- bm.random.seed()
- input = bm.random.randn(2, 32, 16)
- net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
- output = net(input)
-
- @parameterized.product(
- axis=(-1, 0, 1, 2)
- )
- def test_AdaptiveMaxPool2d_v1(self, axis):
- bm.random.seed()
- input = bm.random.randn(32, 16, 12)
- net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis)
- output = net(input)
-
- @parameterized.product(
- axis=(-1, 0, 1, 2, 3)
- )
- def test_AdaptiveMaxPool2d_v2(self, axis):
- bm.random.seed()
- input = bm.random.randn(2, 32, 16, 12)
- net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis)
- # output = net(input)
-
- @parameterized.product(
- axis=(-1, 0, 1, 2, 3)
- )
- def test_AdaptiveMaxPool3d_v1(self, axis):
- bm.random.seed()
- input = bm.random.randn(2, 128, 64, 32)
- net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
- output = net(input)
- print()
-
- @parameterized.product(
- axis=(-1, 0, 1, 2, 3, 4)
- )
- def test_AdaptiveMaxPool3d_v1(self, axis):
- bm.random.seed()
- input = bm.random.randn(2, 128, 64, 32, 16)
- net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
- output = net(input)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def test_maxpool(self):
+ bm.random.seed()
+ x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+ print(jnp.arange(9).reshape(3, 3))
+ print(x)
+ print(x.shape)
+ shared = {'fit': False}
+ with bm.training_environment():
+ net = bp.dnn.MaxPool((2, 2), 1, channel_axis=-1)
+ y = net(shared, x)
+ print("out shape: ", y.shape)
+ expected_y = jnp.array([[4., 5.],
+ [7., 8.]]).reshape((1, 2, 2, 1))
+ np.testing.assert_allclose(y, expected_y)
+ bm.clear_buffer_memory()
+
+ def test_maxpool2(self):
+ bm.random.seed()
+ x = bm.random.rand(10, 20, 20, 4)
+ with bm.training_environment():
+ net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1)
+ y = net(x)
+ print("out shape: ", y.shape)
+ bm.clear_buffer_memory()
+
+ def test_minpool(self):
+ bm.random.seed()
+ x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+ shared = {'fit': False}
+ with bm.training_environment():
+ net = bp.dnn.MinPool((2, 2), 1, channel_axis=-1)
+ y = net(shared, x)
+ print("out shape: ", y.shape)
+ expected_y = jnp.array([
+ [0., 1.],
+ [3., 4.],
+ ]).reshape((1, 2, 2, 1))
+ np.testing.assert_allclose(y, expected_y)
+ bm.clear_buffer_memory()
+
+ def test_avgpool(self):
+ bm.random.seed()
+ x = jnp.full((1, 3, 3, 1), 2.)
+ with bm.training_environment():
+ net = bp.dnn.AvgPool((2, 2), 1, channel_axis=-1)
+ y = net(x)
+ print("out shape: ", y.shape)
+ np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
+ bm.clear_buffer_memory()
+
+ def test_MaxPool2d_v1(self):
+ bm.random.seed()
+ arr = bm.random.rand(16, 32, 32, 8)
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr)
+ self.assertTrue(out.shape == (16, 16, 16, 8))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=None)(arr)
+ self.assertTrue(out.shape == (16, 32, 16, 4))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
+ self.assertTrue(out.shape == (16, 32, 17, 5))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
+ self.assertTrue(out.shape == (16, 32, 18, 5))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 17, 8))
+
+ out = bp.dnn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 32, 5))
+ bm.clear_buffer_memory()
+
+ def test_AvgPool2d_v1(self):
+ bm.random.seed()
+ arr = bm.random.rand(16, 32, 32, 8)
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1)(arr)
+ self.assertTrue(out.shape == (16, 16, 16, 8))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=None)(arr)
+ self.assertTrue(out.shape == (16, 32, 16, 4))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
+ self.assertTrue(out.shape == (16, 32, 17, 5))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
+ self.assertTrue(out.shape == (16, 32, 18, 5))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 17, 8))
+
+ out = bp.dnn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
+ self.assertTrue(out.shape == (16, 17, 32, 5))
+ bm.clear_buffer_memory()
+
+ @parameterized.named_parameters(
+ dict(testcase_name=f'target_size={target_size}',
+ target_size=target_size)
+ for target_size in [10, 9, 8, 7, 6]
+ )
+ def test_adaptive_pool1d(self, target_size):
+ bm.random.seed()
+ from brainpy._src.dnn.pooling import _adaptive_pool1d
+
+ arr = bm.random.rand(100)
+ op = jax.numpy.mean
+
+ out = _adaptive_pool1d(arr, target_size, op)
+ print(out.shape)
+ self.assertTrue(out.shape == (target_size,))
+
+ out = _adaptive_pool1d(arr, target_size, op)
+ print(out.shape)
+ self.assertTrue(out.shape == (target_size,))
+ bm.clear_buffer_memory()
+
+ def test_AdaptiveAvgPool2d_v1(self):
+ bm.random.seed()
+ input = bm.random.randn(64, 8, 9)
+
+ output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
+ self.assertTrue(output.shape == (64, 5, 7))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
+ self.assertTrue(output.shape == (64, 2, 3))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
+ self.assertTrue(output.shape == (2, 3, 9))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
+ self.assertTrue(output.shape == (2, 8, 3))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input)
+ self.assertTrue(output.shape == (64, 2, 3))
+ bm.clear_buffer_memory()
+
+ def test_AdaptiveAvgPool2d_v2(self):
+ bm.random.seed()
+ input = bm.random.randn(128, 64, 32, 16)
+
+ output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input)
+ self.assertTrue(output.shape == (128, 64, 5, 7))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input)
+ self.assertTrue(output.shape == (128, 64, 2, 3))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input)
+ self.assertTrue(output.shape == (128, 2, 3, 16))
+
+ output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input)
+ self.assertTrue(output.shape == (128, 64, 2, 3))
+ print()
+ bm.clear_buffer_memory()
+
+ def test_AdaptiveAvgPool3d_v1(self):
+ bm.random.seed()
+ input = bm.random.randn(10, 128, 64, 32)
+ net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], channel_axis=0, mode=bm.nonbatching_mode)
+ output = net(input)
+ self.assertTrue(output.shape == (10, 6, 5, 3))
+ bm.clear_buffer_memory()
+
+ def test_AdaptiveAvgPool3d_v2(self):
+ bm.random.seed()
+ input = bm.random.randn(10, 20, 128, 64, 32)
+ net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], mode=bm.batching_mode)
+ output = net(input)
+ self.assertTrue(output.shape == (10, 6, 5, 3, 32))
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ axis=(-1, 0, 1)
+ )
+ def test_AdaptiveMaxPool1d_v1(self, axis):
+ bm.random.seed()
+ input = bm.random.randn(32, 16)
+ net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2)
+ )
+ def test_AdaptiveMaxPool1d_v2(self, axis):
+ bm.random.seed()
+ input = bm.random.randn(2, 32, 16)
+ net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2)
+ )
+ def test_AdaptiveMaxPool2d_v1(self, axis):
+ bm.random.seed()
+ input = bm.random.randn(32, 16, 12)
+ net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis)
+ output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2, 3)
+ )
+ def test_AdaptiveMaxPool2d_v2(self, axis):
+ bm.random.seed()
+ input = bm.random.randn(2, 32, 16, 12)
+ net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis)
+ # output = net(input)
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2, 3)
+ )
+ def test_AdaptiveMaxPool3d_v1(self, axis):
+ bm.random.seed()
+ input = bm.random.randn(2, 128, 64, 32)
+ net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
+ output = net(input)
+ print()
+ bm.clear_buffer_memory()
+
+ @parameterized.product(
+ axis=(-1, 0, 1, 2, 3, 4)
+ )
+ def test_AdaptiveMaxPool3d_v1(self, axis):
+ bm.random.seed()
+ input = bm.random.randn(2, 128, 64, 32, 16)
+ net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis)
+ output = net(input)
+ bm.clear_buffer_memory()
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
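Note: the test-suite edits above all apply one housekeeping pattern — reseed the global RNG at the start of each test and call `bm.clear_buffer_memory()` at the end, presumably to drop device buffers accumulated by JIT-compiled test bodies between cases. A minimal illustrative sketch of the pattern (not part of the patch; the layer and shapes are arbitrary):

import unittest
import brainpy as bp
import brainpy.math as bm

class ExampleTest(unittest.TestCase):
  def test_flatten_shape(self):
    bm.random.seed()                  # re-initialize the global RNG for this test
    layer = bp.dnn.Flatten(mode=bm.BatchingMode())
    x = bm.random.randn(4, 2, 3)
    y = layer.update(x)               # batch axis kept, remaining axes flattened
    self.assertEqual(y.shape, (4, 6))
    bm.clear_buffer_memory()          # release device buffers created by this test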
From 59a5004b263f67cdf240cb34cd5b2a7fe988a756 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 20:26:21 +0800
Subject: [PATCH 038/326] change `+=` to `=`
---
brainpy/_src/dyn/neurons/lif.py | 42 ++++++++++++++++-----------------
1 file changed, 21 insertions(+), 21 deletions(-)
diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py
index 62dceed37..f8ba045fd 100644
--- a/brainpy/_src/dyn/neurons/lif.py
+++ b/brainpy/_src/dyn/neurons/lif.py
@@ -225,7 +225,7 @@ def __init__(
def derivative(self, V, t, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
return (-V + self.V_rest + self.R * I) / self.tau
def reset_state(self, batch_size=None):
@@ -265,7 +265,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
super().update(x)
@@ -413,7 +413,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -573,7 +573,7 @@ def __init__(
def derivative(self, V, t, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau
return dvdt
@@ -617,7 +617,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -740,7 +740,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -887,7 +887,7 @@ def __init__(
def dV(self, V, t, w, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau
return dVdt
@@ -951,7 +951,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -1094,7 +1094,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -1225,7 +1225,7 @@ def __init__(
def derivative(self, V, t, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau
return dVdt
@@ -1267,7 +1267,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -1389,7 +1389,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -1536,7 +1536,7 @@ def __init__(
def dV(self, V, t, w, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau
return dVdt
@@ -1598,7 +1598,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -1738,7 +1738,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -1913,7 +1913,7 @@ def dVth(self, V_th, t, V):
def dV(self, V, t, I1, I2, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau
@property
@@ -1982,7 +1982,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -2149,7 +2149,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -2283,7 +2283,7 @@ def __init__(
def dV(self, V, t, u, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
dVdt = 0.04 * V * V + 5 * V + 140 - u + I
return dVdt
@@ -2347,7 +2347,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
@@ -2483,7 +2483,7 @@ def derivative(self):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
return super().update(x)
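Note: the rewrite above replaces in-place accumulation of the input current with plain reassignment. The likely motivation is that `I += out(V)` mutates the incoming value when `I` is a BrainPy `Array` or a traced value — exactly the situation the tracer check tightened in the next patch is meant to reject inside transformed functions — whereas `I = I + out(V)` builds a new value and leaves the argument untouched. Illustrative sketch of the accumulation pattern (not part of the patch):

def accumulate_current(V, I, cur_inputs):
  # Sum all registered current inputs without mutating the incoming value;
  # the functional update stays safe under jax.jit and other transformations.
  for out in cur_inputs.values():
    I = I + out(V)
  return I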
From b0b2df38de00413c384f23004622832e03fa7f08 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 20:26:28 +0800
Subject: [PATCH 039/326] fix
---
brainpy/_src/math/ndarray.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index fe997846d..0872d0bc8 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -80,7 +80,7 @@ def __init__(self, value, dtype=None):
def _check_tracer(self):
self_value = self.value
- if hasattr(self_value, '_trace'):
+ if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'):
if len(self_value._trace.main.jaxpr_stack) == 0:
raise RuntimeError('This Array is modified during the transformation. '
'BrainPy only supports transformations for Variable. '
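Note: the guard added above makes the leaked-tracer check defensive — `jaxpr_stack` is only dereferenced when the trace object actually carries it, since (presumably) not every tracer type, or every JAX version, exposes that attribute on `_trace.main`. Illustrative sketch of the guarded check (not part of the patch):

def _looks_like_leaked_tracer(self_value):
  # Inspect the jaxpr stack only when it exists; other trace kinds are skipped.
  if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'):
    return len(self_value._trace.main.jaxpr_stack) == 0
  return False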
From dbd31ec903e3696c98a1608a752f7c939039ba90 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 21:31:23 +0800
Subject: [PATCH 040/326] fix test bugs
---
.../_src/measure/tests/test_correlation.py | 55 ++++++++++++-------
1 file changed, 34 insertions(+), 21 deletions(-)
diff --git a/brainpy/_src/measure/tests/test_correlation.py b/brainpy/_src/measure/tests/test_correlation.py
index ab70463d2..8e1b17d8e 100644
--- a/brainpy/_src/measure/tests/test_correlation.py
+++ b/brainpy/_src/measure/tests/test_correlation.py
@@ -2,63 +2,74 @@
import unittest
+from functools import partial
+
+from jax import jit
+
import brainpy as bp
import brainpy.math as bm
-from jax import jit
-from functools import partial
class TestCrossCorrelation(unittest.TestCase):
def test_c(self):
- spikes = bp.math.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T
+ bm.random.seed()
+ spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T
cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.)
f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.))
cc2 = f_cc(spikes)
print(cc1, cc2)
self.assertTrue(cc1 == cc2)
+ bm.clear_buffer_memory()
def test_cc(self):
- spikes = bp.math.ones((1000, 10))
+ bm.random.seed()
+ spikes = bm.ones((1000, 10))
cc1 = bp.measure.cross_correlation(spikes, 1.)
self.assertTrue(cc1 == 1.)
- spikes = bp.math.zeros((1000, 10))
+ spikes = bm.zeros((1000, 10))
cc2 = bp.measure.cross_correlation(spikes, 1.)
self.assertTrue(cc2 == 0.)
+ bm.clear_buffer_memory()
+
def test_cc2(self):
- bp.math.random.seed()
- spikes = bp.math.random.randint(0, 2, (1000, 10))
+ bm.random.seed()
+ spikes = bm.random.randint(0, 2, (1000, 10))
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
def test_cc3(self):
- bp.math.random.seed()
- spikes = bp.math.random.random((1000, 100)) < 0.8
+ bm.random.seed()
+ spikes = bm.random.random((1000, 100)) < 0.8
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
def test_cc4(self):
- bp.math.random.seed()
- spikes = bp.math.random.random((1000, 100)) < 0.2
+ bm.random.seed()
+ spikes = bm.random.random((1000, 100)) < 0.2
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
def test_cc5(self):
- bp.math.random.seed()
- spikes = bp.math.random.random((1000, 100)) < 0.05
+ bm.random.seed()
+ spikes = bm.random.random((1000, 100)) < 0.05
print(bp.measure.cross_correlation(spikes, 1.))
print(bp.measure.cross_correlation(spikes, 0.5))
+ bm.clear_buffer_memory()
class TestVoltageFluctuation(unittest.TestCase):
def test_vf1(self):
- rng = bp.math.random.RandomState(122)
+ rng = bm.random.RandomState(122)
voltages = rng.normal(0, 10, size=(1000, 100))
print(bp.measure.voltage_fluctuation(voltages))
bm.enable_x64()
- voltages = bp.math.ones((1000, 100)).value
+ voltages = bm.ones((1000, 100)).value
r1 = bp.measure.voltage_fluctuation(voltages)
jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False))
@@ -68,30 +79,32 @@ def test_vf1(self):
# self.assertTrue(r1 == r2)
bm.disable_x64()
+ bm.clear_buffer_memory()
class TestFunctionalConnectivity(unittest.TestCase):
def test_cf1(self):
- bp.math.random.seed()
- act = bp.math.random.random((10000, 3))
+ bm.random.seed()
+ act = bm.random.random((10000, 3))
r1 = bp.measure.functional_connectivity(act)
jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False))
r2 = jit_f(act)
self.assertTrue(bm.allclose(r1, r2))
+ bm.clear_buffer_memory()
class TestMatrixCorrelation(unittest.TestCase):
def test_mc(self):
- bp.math.random.seed()
- A = bp.math.random.random((100, 100))
- B = bp.math.random.random((100, 100))
+ bm.random.seed()
+ A = bm.random.random((100, 100))
+ B = bm.random.random((100, 100))
r1 = (bp.measure.matrix_correlation(A, B))
jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False))
r2 = jit_f(A, B)
-
self.assertTrue(bm.allclose(r1, r2))
+ bm.clear_buffer_memory()
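Note: beyond the seeding/clean-up housekeeping, the correlation tests above rely on one pattern worth spelling out — the jitted variants are built with `numpy=False`, which (per the function's own docstring later in this series) keeps the output as a JAX array so the measure can be JIT compiled. Illustrative sketch (not part of the patch):

from functools import partial
from jax import jit
import brainpy as bp
import brainpy.math as bm

spikes = bm.random.random((1000, 100)) < 0.2           # boolean spike raster (time, neurons)
cc_eager = bp.measure.cross_correlation(spikes, 1.)    # eager call, default output type
cc_jit = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.))(spikes)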
From 37e6a4b8ffef649a2f1466399a6c8601267da8bd Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 22:02:51 +0800
Subject: [PATCH 041/326] fix test bugs
---
brainpy/__init__.py | 5 ++++-
brainpy/_src/math/op_registers/tests/test_ei_net.py | 4 +++-
brainpy/_src/math/tests/test_op_register.py | 8 ++------
brainpy/_src/math/tests/test_oprators.py | 6 ++++++
brainpy/_src/optimizers/tests/test_scheduler.py | 11 ++++++++++-
5 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 93db462d5..4b2f24822 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -63,7 +63,10 @@
Projection as Projection,
)
DynamicalSystemNS = DynamicalSystem
-
+# delays
+from brainpy._src.delay import (
+ VariableDelay as VariableDelay,
+)
# building blocks
from brainpy import (
diff --git a/brainpy/_src/math/op_registers/tests/test_ei_net.py b/brainpy/_src/math/op_registers/tests/test_ei_net.py
index 4f3da1596..24a1a6a6c 100644
--- a/brainpy/_src/math/op_registers/tests/test_ei_net.py
+++ b/brainpy/_src/math/op_registers/tests/test_ei_net.py
@@ -75,10 +75,12 @@ def __init__(self, scale):
def test1():
+ bm.random.seed()
net2 = EINet(scale=0.1)
runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)])
r = runner2.predict(100., eval_time=True)
- print(r)
+ bm.clear_buffer_memory()
+
diff --git a/brainpy/_src/math/tests/test_op_register.py b/brainpy/_src/math/tests/test_op_register.py
index 4d47782a9..6917202ad 100644
--- a/brainpy/_src/math/tests/test_op_register.py
+++ b/brainpy/_src/math/tests/test_op_register.py
@@ -118,7 +118,7 @@ def test_op(self):
bm.random.seed(123)
fig, gs = bp.visualize.get_figure(1, 2, 4, 5)
- net = EINet(ExponentialSyn, scale=1., method='euler')
+ net = EINet(ExponentialSyn, scale=0.1, method='euler')
runner = bp.DSRunner(
net,
inputs=[(net.E.input, 20.), (net.I.input, 20.)],
@@ -129,7 +129,7 @@ def test_op(self):
ax = fig.add_subplot(gs[0, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax)
- net3 = EINet(ExponentialSyn3, scale=1., method='euler')
+ net3 = EINet(ExponentialSyn3, scale=0.1, method='euler')
runner3 = bp.DSRunner(
net3,
inputs=[(net3.E.input, 20.), (net3.I.input, 20.)],
@@ -137,9 +137,5 @@ def test_op(self):
)
t, _ = runner3.run(100., eval_time=True)
print(t)
- # ax = fig.add_subplot(gs[0, 1])
- # bp.visualize.raster_plot(runner3.mon.ts, runner3.mon['E.spike'], ax=ax, show=True)
-
- # clear
plt.close()
bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/tests/test_oprators.py b/brainpy/_src/math/tests/test_oprators.py
index a0bd8dbe9..42bdcb95e 100644
--- a/brainpy/_src/math/tests/test_oprators.py
+++ b/brainpy/_src/math/tests/test_oprators.py
@@ -35,30 +35,35 @@ def test_syn2post_sum(self):
segment_ids = bm.array([0, 0, 1, 1, 2])
self.assertTrue(bm.array_equal(bm.syn2post_sum(data, segment_ids, 3),
bm.asarray([1, 5, 4])))
+ bm.clear_buffer_memory()
def test_syn2post_max(self):
data = bm.arange(5)
segment_ids = bm.array([0, 0, 1, 1, 2])
self.assertTrue(bm.array_equal(bm.syn2post_max(data, segment_ids, 3),
bm.asarray([1, 3, 4])))
+ bm.clear_buffer_memory()
def test_syn2post_min(self):
data = bm.arange(5)
segment_ids = bm.array([0, 0, 1, 1, 2])
self.assertTrue(bm.array_equal(bm.syn2post_min(data, segment_ids, 3),
bm.asarray([0, 2, 4])))
+ bm.clear_buffer_memory()
def test_syn2post_prod(self):
data = bm.arange(5)
segment_ids = bm.array([0, 0, 1, 1, 2])
self.assertTrue(bm.array_equal(bm.syn2post_prod(data, segment_ids, 3),
bm.asarray([0, 6, 4])))
+ bm.clear_buffer_memory()
def test_syn2post_mean(self):
data = bm.arange(5)
segment_ids = bm.array([0, 0, 1, 1, 2])
self.assertTrue(bm.array_equal(bm.syn2post_mean(data, segment_ids, 3),
bm.asarray([0.5, 2.5, 4.])))
+ bm.clear_buffer_memory()
def test_syn2post_softmax(self):
data = bm.arange(5)
@@ -79,6 +84,7 @@ def test_syn2post_softmax(self):
data = bm.arange(5)
segment_ids = bm.array([0, 0, 1, 1, 2])
print(bm.syn2post_softmax(data, segment_ids, 4))
+ bm.clear_buffer_memory()
#
# class TestSparseMatmul(unittest.TestCase):
diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py
index e614ccca1..d283c2ff1 100644
--- a/brainpy/_src/optimizers/tests/test_scheduler.py
+++ b/brainpy/_src/optimizers/tests/test_scheduler.py
@@ -5,6 +5,7 @@
import jax.numpy
import matplotlib.pyplot as plt
from absl.testing import parameterized
+import brainpy.math as bm
from brainpy._src.optimizers import scheduler
@@ -17,6 +18,7 @@ class TestMultiStepLR(parameterized.TestCase):
last_epoch=[-1, 0, 5, 10]
)
def test2(self, last_epoch):
+ bm.random.seed()
scheduler1 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch)
scheduler2 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch)
@@ -26,6 +28,8 @@ def test2(self, last_epoch):
scheduler2.step_epoch()
print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}')
self.assertTrue(lr1 == lr2)
+ bm.clear_buffer_memory()
+
class TestStepLR(parameterized.TestCase):
@@ -36,6 +40,7 @@ class TestStepLR(parameterized.TestCase):
for last_epoch in [-1, 0, 5, 10]
)
def test1(self, last_epoch):
+ bm.random.seed()
scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
@@ -45,10 +50,12 @@ def test1(self, last_epoch):
scheduler2.step_epoch()
print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}')
self.assertTrue(lr1 == lr2)
+ bm.clear_buffer_memory()
class TestCosineAnnealingLR(unittest.TestCase):
def test1(self):
+ bm.random.seed()
max_epoch = 50
iters = 200
sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1)
@@ -70,10 +77,12 @@ def test1(self):
plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1]))
plt.show()
plt.close()
+ bm.clear_buffer_memory()
class TestCosineAnnealingWarmRestarts(unittest.TestCase):
def test1(self):
+ bm.random.seed()
max_epoch = 50
iters = 200
sch = scheduler.CosineAnnealingWarmRestarts(0.1,
@@ -97,5 +106,5 @@ def test1(self):
plt.plot(jax.numpy.asarray(all_lr2))
plt.show()
plt.close()
-
+ bm.clear_buffer_memory()
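Note: the scheduler tests touched above exercise two equivalent ways of querying a learning-rate schedule — calling the scheduler with an explicit epoch index, or stepping its internal epoch counter and calling it with no argument. A minimal illustrative sketch (not part of the patch), mirroring the `MultiStepLR` usage in the tests:

from brainpy._src.optimizers import scheduler

sch_by_index = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=-1)
sch_by_step = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=-1)
for i in range(1, 25):
  lr1 = sch_by_index(i - 1)    # explicit epoch index
  lr2 = sch_by_step()          # uses the internally tracked epoch
  sch_by_step.step_epoch()
  assert lr1 == lr2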
From d9c4a3cd3fdfd06502c35abbc7f128ed6fe03561 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 22:08:27 +0800
Subject: [PATCH 042/326] fix test bugs
---
.../_src/integrators/fde/tests/test_Caputo.py | 2 ++
brainpy/_src/integrators/fde/tests/test_GL.py | 2 ++
.../integrators/ode/tests/test_delay_ode.py | 6 ++++
.../ode/tests/test_ode_method_adaptive_rk.py | 2 ++
.../ode/tests/test_ode_method_exp_euler.py | 2 ++
.../ode/tests/test_ode_method_rk.py | 2 ++
.../_src/integrators/sde/tests/test_normal.py | 17 ++++++++---
.../integrators/tests/test_integ_runner.py | 29 ++++++++++---------
.../_src/optimizers/tests/test_scheduler.py | 5 ----
9 files changed, 44 insertions(+), 23 deletions(-)
diff --git a/brainpy/_src/integrators/fde/tests/test_Caputo.py b/brainpy/_src/integrators/fde/tests/test_Caputo.py
index 4948fe770..15101d6a8 100644
--- a/brainpy/_src/integrators/fde/tests/test_Caputo.py
+++ b/brainpy/_src/integrators/fde/tests/test_Caputo.py
@@ -10,6 +10,7 @@
class TestCaputoL1(unittest.TestCase):
def test1(self):
+ bp.math.random.seed()
bp.math.enable_x64()
alpha = 0.9
intg = bp.fde.CaputoL1Schema(lambda a, t: a,
@@ -32,4 +33,5 @@ def test1(self):
print(memory_trace[0], )
print(memory_trace2[0], bp.math.array_equal(memory_trace[0], memory_trace2[0]))
+ bp.math.clear_buffer_memory()
bp.math.disable_x64()
diff --git a/brainpy/_src/integrators/fde/tests/test_GL.py b/brainpy/_src/integrators/fde/tests/test_GL.py
index f5bdb09ed..1b8217a07 100644
--- a/brainpy/_src/integrators/fde/tests/test_GL.py
+++ b/brainpy/_src/integrators/fde/tests/test_GL.py
@@ -20,6 +20,7 @@ def lorenz(x, y, z, t):
dz = x * y - c * z
return dx, dy, dz
+ bp.math.random.seed()
integral = bp.fde.GLShortMemory(lorenz,
alpha=0.99,
num_memory=500,
@@ -32,5 +33,6 @@ def lorenz(x, y, z, t):
plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten())
plt.show(block=block)
+ bp.math.clear_buffer_memory()
diff --git a/brainpy/_src/integrators/ode/tests/test_delay_ode.py b/brainpy/_src/integrators/ode/tests/test_delay_ode.py
index 4efce9cc6..991bf0ce0 100644
--- a/brainpy/_src/integrators/ode/tests/test_delay_ode.py
+++ b/brainpy/_src/integrators/ode/tests/test_delay_ode.py
@@ -62,6 +62,7 @@ def __init__(self, *args, **kwargs):
for name in get_supported_methods()
)
def test1(self, method):
+ bm.random.seed()
case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round')
case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='linear_interp')
@@ -87,6 +88,8 @@ def test1(self, method):
# plt.show(block=block)
# plt.close()
+ bm.clear_buffer_memory()
+
class TestNonConstantHist(parameterized.TestCase):
def get_eq(self, xdelay):
@@ -102,6 +105,8 @@ def __init__(self, *args, **kwargs):
for name in get_supported_methods()
)
def test1(self, method):
+ bm.random.seed()
+
delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01, interp_method='round')
delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01)
case1 = delay_odeint(4., self.get_eq(delay1), state_delays={'x': delay1}, dt=0.01, method=method)
@@ -114,3 +119,4 @@ def test1(self, method):
# self.assertTrue((case1['x'] - self.ref1['x']).mean() < 1e-1)
# self.assertTrue((case2['x'] - self.ref2['x']).mean() < 1e-1)
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py
index 6edb75862..d9cc1cbf2 100644
--- a/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py
@@ -66,4 +66,6 @@ def test_all_methods(self):
adaptive_rk.CashKarp,
adaptive_rk.BogackiShampine,
adaptive_rk.HeunEuler]:
+ bm.random.seed()
run_integrator(method, show=False)
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
index 2b8dd6781..42ad7f487 100644
--- a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py
@@ -103,6 +103,7 @@ def update(self, tdi):
self.n.value = n
self.input[:] = 0.
+ bm.random.seed()
hh1 = HH(1, method='exp_euler')
runner1 = bp.DSRunner(hh1, inputs=('input', 2.), monitors=['V', 'h', 'n'])
runner1.run(100)
@@ -125,4 +126,5 @@ def update(self, tdi):
self.assertTrue(diff < 1e0)
plt.close()
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py
index 08a7a5936..a8e5535ab 100644
--- a/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py
@@ -74,7 +74,9 @@ def test_all_methods(self):
explicit_rk.RK4,
explicit_rk.Ralston4,
explicit_rk.RK4Rule38]:
+ bm.random.seed()
mon_x, mon_y, mon_z = run_integrator(method)
assert np.linalg.norm(mon_x - _baseline_x) / (duration / dt) < 0.1
assert np.linalg.norm(mon_y - _baseline_y) / (duration / dt) < 0.1
assert np.linalg.norm(mon_z - _baseline_z) / (duration / dt) < 0.1
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/integrators/sde/tests/test_normal.py b/brainpy/_src/integrators/sde/tests/test_normal.py
index 5a15a9680..503161b31 100644
--- a/brainpy/_src/integrators/sde/tests/test_normal.py
+++ b/brainpy/_src/integrators/sde/tests/test_normal.py
@@ -4,6 +4,7 @@
import unittest
import brainpy as bp
+import brainpy.math as bm
import matplotlib.pyplot as plt
from brainpy._src.integrators.sde.normal import ExponentialEuler
@@ -21,21 +22,24 @@ def lorenz_g(x, y, z, t, **kwargs):
dy = lambda y, t, x, z, rho=28: x * (rho - z) - y
dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z
+ bm.random.seed()
intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]),
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
wiener_type=bp.integrators.SCALAR_WIENER,
var_type=bp.integrators.POP_VAR,
show_code=True)
- runner = bp.integrators.IntegratorRunner(intg,
- monitors=['x', 'y', 'z'],
- dt=0.001, inits=[1., 1., 0.])
+ runner = bp.IntegratorRunner(intg,
+ monitors=['x', 'y', 'z'],
+ dt=0.001, inits=[1., 1., 0.])
runner.run(100.)
plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten())
if show:
plt.show()
plt.close()
+ bm.clear_buffer_memory()
+
def test2(self):
p = 0.1
@@ -50,6 +54,7 @@ def lorenz_g(x, y, z, t, **kwargs):
dy = lambda y, t, x, z, rho=28: x * (rho - z) - y
dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z
+ bm.random.seed()
intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]),
g=lorenz_g,
intg_type=bp.integrators.ITO_SDE,
@@ -60,6 +65,7 @@ def lorenz_g(x, y, z, t, **kwargs):
dt=0.001, inits=[1., 1., 0.], jit=False)
with self.assertRaises(ValueError):
runner.run(100.)
+ bm.clear_buffer_memory()
def test3(self):
p = 0.1
@@ -70,6 +76,7 @@ def lorenz_g(x, y, z, t, **kwargs):
bp.math.asarray([p * y, p2 * y]).T, \
bp.math.asarray([p * z, p2 * z]).T
+ bm.random.seed()
dx = lambda x, t, y, sigma=10: sigma * (y - x)
dy = lambda y, t, x, z, rho=28: x * (rho - z) - y
dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z
@@ -91,6 +98,7 @@ def lorenz_g(x, y, z, t, **kwargs):
if show:
plt.show()
plt.close()
+ bm.clear_buffer_memory()
class TestMilstein(unittest.TestCase):
@@ -108,6 +116,7 @@ def test1(self):
fy = lambda y, t, x, z: x * (rho - z) - y
fz = lambda z, t, x, y: x * y - beta * z
+ bm.random.seed()
intg = bp.sdeint(f=bp.JointEq(fx, fy, fz),
g=bp.JointEq(gx, gy, gz),
intg_type=bp.integrators.ITO_SDE,
@@ -124,4 +133,4 @@ def test1(self):
if show:
plt.show()
plt.close()
-
+ bm.clear_buffer_memory()
diff --git a/brainpy/_src/integrators/tests/test_integ_runner.py b/brainpy/_src/integrators/tests/test_integ_runner.py
index 6633a8161..353735184 100644
--- a/brainpy/_src/integrators/tests/test_integ_runner.py
+++ b/brainpy/_src/integrators/tests/test_integ_runner.py
@@ -10,6 +10,7 @@
class TestIntegratorRunnerForODEs(TestCase):
def test_ode(self):
+
sigma = 10
beta = 8 / 3
rho = 28
@@ -21,16 +22,16 @@ def lorenz(x, y, z, t):
dz = x * y - beta * z
return dx, dy, dz
- runner = bp.integrators.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.])
+ runner = bp.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.])
runner.run(100.)
fig = plt.figure()
fig.add_subplot(111, projection='3d')
plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0], )
plt.show()
- runner = bp.integrators.IntegratorRunner(lorenz,
- monitors=['x', 'y', 'z'],
- inits=[1., (1., 0.), (1., 0.)])
+ runner = bp.IntegratorRunner(lorenz,
+ monitors=['x', 'y', 'z'],
+ inits=[1., (1., 0.), (1., 0.)])
runner.run(100.)
for i in range(2):
fig = plt.figure()
@@ -47,7 +48,7 @@ def test_ode2(self):
dw = lambda w, t, V: (V + a - b * w) / tau
fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1)
- runner = bp.integrators.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.])
+ runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.])
runner.run(100., args=dict(Iext=1.5))
bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V')
bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=True)
@@ -61,9 +62,9 @@ def test_ode3(self):
fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1)
Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 500, 200], return_length=True)
- runner = bp.integrators.IntegratorRunner(fhn,
- monitors=['V', 'w'],
- inits=[1., 1.])
+ runner = bp.IntegratorRunner(fhn,
+ monitors=['V', 'w'],
+ inits=[1., 1.])
runner.run(duration, dyn_args=dict(Iext=Iext))
bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V')
bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=True)
@@ -76,9 +77,9 @@ def test_ode_continuous_run(self):
dw = lambda w, t, V: (V + a - b * w) / tau
fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1)
- runner = bp.integrators.IntegratorRunner(fhn,
- monitors=['V', 'w'],
- inits=[1., 1.])
+ runner = bp.IntegratorRunner(fhn,
+ monitors=['V', 'w'],
+ inits=[1., 1.])
Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True)
runner.run(duration, dyn_args=dict(Iext=Iext))
bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V')
@@ -100,9 +101,9 @@ def test_ode_dyn_args(self):
Iext, duration = bp.inputs.section_input([0., 1., 0.5],
[200, 500, 199],
return_length=True)
- runner = bp.integrators.IntegratorRunner(fhn,
- monitors=['V', 'w'],
- inits=[1., 1.])
+ runner = bp.IntegratorRunner(fhn,
+ monitors=['V', 'w'],
+ inits=[1., 1.])
with self.assertRaises(ValueError):
runner.run(duration + 1, dyn_args=dict(Iext=Iext))
diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py
index d283c2ff1..f08ed9233 100644
--- a/brainpy/_src/optimizers/tests/test_scheduler.py
+++ b/brainpy/_src/optimizers/tests/test_scheduler.py
@@ -13,7 +13,6 @@
class TestMultiStepLR(parameterized.TestCase):
-
@parameterized.product(
last_epoch=[-1, 0, 5, 10]
)
@@ -31,9 +30,7 @@ def test2(self, last_epoch):
bm.clear_buffer_memory()
-
class TestStepLR(parameterized.TestCase):
-
@parameterized.named_parameters(
{'testcase_name': f'last_epoch={last_epoch}',
'last_epoch': last_epoch}
@@ -43,7 +40,6 @@ def test1(self, last_epoch):
bm.random.seed()
scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch)
-
for i in range(1, 25):
lr1 = scheduler1(i + last_epoch)
lr2 = scheduler2()
@@ -107,4 +103,3 @@ def test1(self):
plt.show()
plt.close()
bm.clear_buffer_memory()
-
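Note: two recurring changes run through the integrator tests above — each test now seeds the RNG and clears buffer memory, and `bp.integrators.IntegratorRunner` is shortened to the top-level alias `bp.IntegratorRunner`. A minimal illustrative sketch of the runner usage (not part of the patch; the toy one-variable ODE is arbitrary):

import brainpy as bp

dv = lambda v, t, tau=10.: -v / tau            # toy exponential decay
intg = bp.odeint(dv, method='rk4', dt=0.1)
runner = bp.IntegratorRunner(intg, monitors=['v'], inits=[1.])
runner.run(100.)                               # monitored trajectory available as runner.mon.v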
From df1379529b33f55d3a7ef1bb948f2aa28e8d61a4 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 22:43:16 +0800
Subject: [PATCH 043/326] fix test bugs
---
brainpy/_src/integrators/sde/tests/test_sde_scalar.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/brainpy/_src/integrators/sde/tests/test_sde_scalar.py b/brainpy/_src/integrators/sde/tests/test_sde_scalar.py
index 6f9fae51a..813bb935b 100644
--- a/brainpy/_src/integrators/sde/tests/test_sde_scalar.py
+++ b/brainpy/_src/integrators/sde/tests/test_sde_scalar.py
@@ -2,13 +2,12 @@
import unittest
+import matplotlib.pyplot as plt
import numpy as np
import pytest
import brainpy as bp
from brainpy.integrators import sde
-import matplotlib.pyplot as plt
-
block = False
sigma = 10
@@ -29,6 +28,7 @@ def lorenz_g(x, y, z, t):
def lorenz_system(method, **kwargs):
+ bp.math.seed()
integral = bp.math.jit(method(f=lorenz_f,
g=lorenz_g,
show_code=True,
@@ -57,6 +57,7 @@ def lorenz_system(method, **kwargs):
ax.set_xlabel('z')
plt.show(block=block)
plt.close(fig)
+ bp.math.clear_buffer_memory()
class TestScalarWienerIntegral(unittest.TestCase):
From 19cc77c43107049b0b96ec227a908cb2c3df7219 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Thu, 13 Jul 2023 23:11:28 +0800
Subject: [PATCH 044/326] fix test bugs
---
brainpy/_src/integrators/sde/tests/test_sde_scalar.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/brainpy/_src/integrators/sde/tests/test_sde_scalar.py b/brainpy/_src/integrators/sde/tests/test_sde_scalar.py
index 813bb935b..f9d4e4e5f 100644
--- a/brainpy/_src/integrators/sde/tests/test_sde_scalar.py
+++ b/brainpy/_src/integrators/sde/tests/test_sde_scalar.py
@@ -28,7 +28,7 @@ def lorenz_g(x, y, z, t):
def lorenz_system(method, **kwargs):
- bp.math.seed()
+ bp.math.random.seed()
integral = bp.math.jit(method(f=lorenz_f,
g=lorenz_g,
show_code=True,
From f4ff69a18a5ac925de57aceaab33301e80ef0248 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 14 Jul 2023 13:15:32 +0800
Subject: [PATCH 045/326] fix test bugs
---
brainpy/_src/measure/correlation.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/brainpy/_src/measure/correlation.py b/brainpy/_src/measure/correlation.py
index 9e3dd9d0a..5cfd1b0d1 100644
--- a/brainpy/_src/measure/correlation.py
+++ b/brainpy/_src/measure/correlation.py
@@ -107,6 +107,10 @@ def _cc(i, j):
return np.mean(np.asarray(res))
+def _f_signal(signal):
+ return jnp.mean(signal * signal) - jnp.mean(signal) ** 2
+
+
def voltage_fluctuation(potentials, numpy=True, method='loop'):
r"""Calculate neuronal synchronization via voltage variance.
@@ -177,15 +181,14 @@ def voltage_fluctuation(potentials, numpy=True, method='loop'):
avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2
if method == 'loop':
- _var = lambda aa: bm.for_loop(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2,
- operands=jnp.moveaxis(aa, 0, 1))
+ _var = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1))
elif method == 'vmap':
- _var = vmap(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2, in_axes=1)
+ _var = vmap(_f_signal, in_axes=1)(potentials)
else:
raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')
- var_mean = jnp.mean(_var(potentials))
+ var_mean = jnp.mean(_var)
r = jnp.where(var_mean == 0., 1., avg_var / var_mean)
return bm.as_numpy(r) if numpy else r
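Note: the refactor above hoists the per-signal variance into a module-level helper (`_f_signal`, i.e. E[V^2] - E[V]^2) and applies it to the potentials directly via `bm.for_loop` or `vmap`, rather than building the mapping lazily inside lambdas. A condensed illustrative sketch of the chi-square synchrony measure it computes (not part of the patch; assumes `potentials` has shape (time, neurons)):

import jax.numpy as jnp
from jax import vmap

def _signal_var(sig):
  # variance of one neuron's voltage trace: E[V^2] - E[V]^2
  return jnp.mean(sig * sig) - jnp.mean(sig) ** 2

def sync_index(potentials):
  avg = jnp.mean(potentials, axis=1)                               # population-averaged voltage V(t)
  avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2               # sigma_V^2
  var_mean = jnp.mean(vmap(_signal_var, in_axes=1)(potentials))    # (1/N) * sum_i sigma_{V_i}^2
  return jnp.where(var_mean == 0., 1., avg_var / var_mean)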
From e6d6892fb1da89ea60186ddb771b69d61f87858e Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 14 Jul 2023 13:55:47 +0800
Subject: [PATCH 046/326] fix test bugs
---
brainpy/_src/measure/correlation.py | 34 +++++++------------
.../_src/measure/tests/test_correlation.py | 6 ++--
docs/tutorial_advanced/analysis.rst | 2 +-
3 files changed, 16 insertions(+), 26 deletions(-)
diff --git a/brainpy/_src/measure/correlation.py b/brainpy/_src/measure/correlation.py
index 5cfd1b0d1..d0d7db17e 100644
--- a/brainpy/_src/measure/correlation.py
+++ b/brainpy/_src/measure/correlation.py
@@ -147,33 +147,23 @@ def voltage_fluctuation(potentials, numpy=True, method='loop'):
\chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N
\sigma_{V_i}^2}
- Parameters
- ----------
- potentials : ndarray
- The membrane potential matrix of the neuron group.
- numpy: bool
- Whether we use numpy array as the functional output.
- If ``False``, this function can be JIT compiled.
- method: str
- The method to calculate all pairs of cross correlation.
- Supports two kinds of methods: `loop` and `vmap`.
- `vmap` method will consume much more memory.
-
- .. versionadded:: 2.2.3.4
-
-
- Returns
- -------
- sync_index : float
- The synchronization index.
-
- References
- ----------
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
inhibitory neurons. Physica D 72:259-282.
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.
+
+ Args:
+ potentials: The membrane potential matrix of the neuron group.
+ numpy: Whether we use numpy array as the functional output. If ``False``, this function can be JIT compiled.
+ method: The method to calculate all pairs of cross correlation.
+ Supports two kinds of methods: `loop` and `vmap`.
+ `vmap` method will consume much more memory.
+
+ .. versionadded:: 2.2.3.4
+
+ Returns:
+ sync_index: The synchronization index.
"""
potentials = bm.as_jax(potentials)
diff --git a/brainpy/_src/measure/tests/test_correlation.py b/brainpy/_src/measure/tests/test_correlation.py
index 8e1b17d8e..d9ed7519b 100644
--- a/brainpy/_src/measure/tests/test_correlation.py
+++ b/brainpy/_src/measure/tests/test_correlation.py
@@ -64,12 +64,12 @@ def test_cc5(self):
class TestVoltageFluctuation(unittest.TestCase):
def test_vf1(self):
- rng = bm.random.RandomState(122)
- voltages = rng.normal(0, 10, size=(1000, 100))
+ bm.random.seed()
+ voltages = bm.random.normal(0, 10, size=(100, 10))
print(bp.measure.voltage_fluctuation(voltages))
bm.enable_x64()
- voltages = bm.ones((1000, 100)).value
+ voltages = bm.ones((100, 10)).value
r1 = bp.measure.voltage_fluctuation(voltages)
jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False))
diff --git a/docs/tutorial_advanced/analysis.rst b/docs/tutorial_advanced/analysis.rst
index 29d8d3886..f574fdb5b 100644
--- a/docs/tutorial_advanced/analysis.rst
+++ b/docs/tutorial_advanced/analysis.rst
@@ -1,4 +1,4 @@
-Interoperation
+Analysis
================
.. toctree::
From 04858cbe61961cf609730f70c42b09671cbd270f Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 14 Jul 2023 14:30:36 +0800
Subject: [PATCH 047/326] remove windows tests
---
.github/workflows/CI.yml | 68 +++++++++----------
.../_src/measure/tests/test_correlation.py | 2 +-
.../_src/optimizers/tests/test_scheduler.py | 2 +-
brainpy/_src/running/jax_multiprocessing.py | 4 +-
4 files changed, 39 insertions(+), 37 deletions(-)
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 7f8fc93c3..b8a43c38c 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -151,40 +151,40 @@ jobs:
#
- test_windows:
- runs-on: windows-latest
- strategy:
- fail-fast: false
- matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11"]
-
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- python -m pip install flake8 pytest
- python -m pip install numpy>=1.21.0
- python -m pip install "jaxlib==0.4.11" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
- python -m pip install jax==0.4.11
- python -m pip install -r requirements-dev.txt
- python -m pip install tqdm brainpylib
- pip uninstall brainpy -y
- python setup.py install
- - name: Lint with flake8
- run: |
- # stop the build if there are Python syntax errors or undefined names
- flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- - name: Test with pytest
- run: |
- cd brainpy
- pytest _src/
+# test_windows:
+# runs-on: windows-latest
+# strategy:
+# fail-fast: false
+# matrix:
+# python-version: ["3.8", "3.9", "3.10", "3.11"]
+#
+# steps:
+# - uses: actions/checkout@v2
+# - name: Set up Python ${{ matrix.python-version }}
+# uses: actions/setup-python@v2
+# with:
+# python-version: ${{ matrix.python-version }}
+# - name: Install dependencies
+# run: |
+# python -m pip install --upgrade pip
+# python -m pip install flake8 pytest
+# python -m pip install numpy>=1.21.0
+# python -m pip install "jaxlib==0.4.11" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
+# python -m pip install jax==0.4.11
+# python -m pip install -r requirements-dev.txt
+# python -m pip install tqdm brainpylib
+# pip uninstall brainpy -y
+# python setup.py install
+# - name: Lint with flake8
+# run: |
+# # stop the build if there are Python syntax errors or undefined names
+# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics
+# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+# - name: Test with pytest
+# run: |
+# cd brainpy
+# pytest _src/
# test_windows_py37:
diff --git a/brainpy/_src/measure/tests/test_correlation.py b/brainpy/_src/measure/tests/test_correlation.py
index d9ed7519b..950dbce1f 100644
--- a/brainpy/_src/measure/tests/test_correlation.py
+++ b/brainpy/_src/measure/tests/test_correlation.py
@@ -69,7 +69,7 @@ def test_vf1(self):
print(bp.measure.voltage_fluctuation(voltages))
bm.enable_x64()
- voltages = bm.ones((100, 10)).value
+ voltages = bm.ones((100, 10))
r1 = bp.measure.voltage_fluctuation(voltages)
jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False))
diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py
index f08ed9233..dbdda0eda 100644
--- a/brainpy/_src/optimizers/tests/test_scheduler.py
+++ b/brainpy/_src/optimizers/tests/test_scheduler.py
@@ -5,8 +5,8 @@
import jax.numpy
import matplotlib.pyplot as plt
from absl.testing import parameterized
-import brainpy.math as bm
+import brainpy.math as bm
from brainpy._src.optimizers import scheduler
show = False
diff --git a/brainpy/_src/running/jax_multiprocessing.py b/brainpy/_src/running/jax_multiprocessing.py
index 719c36953..3520d809f 100644
--- a/brainpy/_src/running/jax_multiprocessing.py
+++ b/brainpy/_src/running/jax_multiprocessing.py
@@ -60,8 +60,10 @@ def jax_vectorize_map(
run_f = vmap(func) if clear_buffer else vmap_func
if isinstance(arguments, dict):
r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
- else:
+ elif isinstance(arguments, (tuple, list)):
r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
+ else:
+ raise TypeError
res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array))
if results is None:
results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
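The change above makes `jax_vectorize_map` reject argument containers other than a dict or a tuple/list. A hedged sketch of the two accepted forms (the export path `bp.running.jax_vectorize_map` and any parameters beyond what this hunk shows are assumptions):

.. code-block:: python

    import brainpy as bp
    import brainpy.math as bm

    def f(a, b):
        return a + b

    xs = bm.arange(8.)
    ys = bm.ones(8)

    # A dict of arrays or a tuple/list of arrays is accepted; anything else
    # now raises TypeError.
    r1 = bp.running.jax_vectorize_map(f, {'a': xs, 'b': ys}, num_parallel=4)
    r2 = bp.running.jax_vectorize_map(f, (xs, ys), num_parallel=4)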
From 3bdc8b03a9bacd49f901d5e5229948342ef21af8 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 19 Jul 2023 22:50:45 +0800
Subject: [PATCH 048/326] add more projection types
---
brainpy/_src/dyn/projections/aligns.py | 688 +++++++++++++++++++++++--
brainpy/dyn/projections.py | 9 +-
2 files changed, 646 insertions(+), 51 deletions(-)
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index 925d7dd22..907a144f2 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -1,15 +1,21 @@
from typing import Optional, Callable, Union
-from brainpy import math as bm
-from brainpy._src.delay import Delay, VariableDelay, DataDelay
-from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic
+import jax
+
+from brainpy import math as bm, check
+from brainpy._src.delay import Delay, VariDelay, DataDelay, DelayAccess
+from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic, Sequential
from brainpy._src.mixin import JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
__all__ = [
- 'ProjAlignPre',
- 'ProjAlignPost',
+ 'VanillaProj',
+ 'ProjAlignPostMg1', 'ProjAlignPostMg2',
+ 'ProjAlignPost1', 'ProjAlignPost2',
+ 'ProjAlignPreMg1', 'ProjAlignPreMg2',
]
+_pre_delay_repr = '_*_align_pre_spk_delay_*_'
+
class _AlignPre(DynamicalSystem):
def __init__(self, syn, delay=None):
@@ -36,33 +42,564 @@ def update(self, *args, **kwargs):
self.out.bind_cond(self.syn(*args, **kwargs))
+class _AlignPreMg(DynamicalSystem):
+ def __init__(self, access, syn):
+ super().__init__()
+ self.access = access
+ self.syn = syn
+
+ def update(self):
+ return self.syn(self.access())
+
+
def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay:
if isinstance(info, bm.Variable):
- return VariableDelay(info)
+ return VariDelay(info)
elif isinstance(info, ReturnInfo):
if isinstance(info.batch_or_mode, int):
- size = (info.batch_or_mode,) + tuple(info.size)
+ shape = (info.batch_or_mode,) + tuple(info.size)
batch_axis = 0
elif isinstance(info.batch_or_mode, bm.NonBatchingMode):
- size = tuple(info.size)
+ shape = tuple(info.size)
batch_axis = None
elif isinstance(info.batch_or_mode, bm.BatchingMode):
- size = (info.batch_or_mode.batch_size,) + tuple(info.size)
+ shape = (info.batch_or_mode.batch_size,) + tuple(info.size)
batch_axis = 0
else:
- size = tuple(info.size)
+ shape = tuple(info.size)
batch_axis = None
- target = bm.Variable(info.init(size),
- batch_axis=batch_axis,
- axis_names=info.axis_names)
- return DataDelay(target, target_init=info.init)
+ if isinstance(info.data, Callable):
+ init = info.data(shape)
+ elif isinstance(info.data, (bm.Array, jax.Array)):
+ init = info.data
+ else:
+ raise TypeError
+ assert init.shape == shape
+ if info.axis_names is not None:
+ assert init.ndim == len(info.axis_names)
+ target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names)
+ return DataDelay(target, data_init=info.data)
else:
raise TypeError
-class ProjAlignPre(Projection):
+def _get_return(return_info):
+ if isinstance(return_info, bm.Variable):
+ return return_info.value
+ elif isinstance(return_info, ReturnInfo):
+ return return_info.get_data()
+ else:
+ raise NotImplementedError
+
+
+class VanillaProj(Projection):
+ """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group.
+
+ **Code Examples**
+
+ To simulate an E/I balanced network model:
+
+ .. code-block::
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
+ self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
+ self.syn2 = bp.dyn.Expon(size=800, tau=10.)
+ self.E = bp.dyn.VanillaProj(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.N)
+ self.I = bp.dyn.VanillaProj(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(self.syn1(spk[:3200]))
+ self.I(self.syn2(spk[3200:]))
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ comm: The synaptic communication.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ out: JointType[DynamicalSystem, BindCondData],
+ post: Dynamic,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, Dynamic)
+ self.post = post
+ self.comm = comm
+
+ # output initialization
+ post.cur_inputs[self.name] = out
+
+ def update(self, x):
+ current = self.comm(x)
+ self.post.cur_inputs[self.name].bind_cond(current)
+ return current
+
+
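Unlike the aligned projections below, `VanillaProj.update` receives an already-computed synaptic value (in the EINet example above, `self.syn1(spk[:3200])`), not raw spikes. A construction sketch under that assumption; in practice the call happens inside a `DynSysGroup.update` as in the docstring:

.. code-block:: python

    import brainpy as bp
    import brainpy.math as bm

    post = bp.dyn.LifRef(4000, V_initializer=bp.init.Normal(-55., 2.))
    syn = bp.dyn.Expon(size=3200, tau=5.)
    proj = bp.dyn.VanillaProj(
        comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
        out=bp.dyn.COBA(E=0.),
        post=post,
    )

    spk = bm.random.rand(3200) < 0.02   # stand-in for delayed pre-synaptic spikes
    proj(syn(spk))                      # comm -> bind conductance onto `post`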
+class ProjAlignPostMg1(Projection):
+ r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
+
+ **Code Examples**
+
+ To define an E/I balanced network model.
+
+ .. code-block:: python
+
+ import brainpy as bp
+ import brainpy.math as bm
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=4000, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.N)
+ self.I = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=4000, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+ Args:
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
+ out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
+ post: Dynamic,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
+ check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
+ check.is_instance(post, Dynamic)
+ self.post = post
+ self.comm = comm
+
+ # synapse and output initialization
+ self._post_repr = f'{syn._identifier} // {out._identifier}'
+ if self._post_repr not in self.post.before_updates:
+ syn_cls = syn()
+ out_cls = out()
+ self.post.cur_inputs[self.name] = out_cls
+ self.post.before_updates[self._post_repr] = _AlignPost(syn_cls, out_cls)
+
+ def update(self, x):
+ current = self.comm(x)
+ syn: _AlignPost = self.post.before_updates[self._post_repr].syn
+ syn.add_current(current) # synapse post current
+ return current
+
+
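A sketch of the merging behaviour implied by the `_post_repr` key above: projections built from the same `syn`/`out` descriptors attach a single `_AlignPost` instance to the post group, so their currents accumulate in one synapse state. Illustrative only; the sharing rule is inferred from the code above.

.. code-block:: python

    import brainpy as bp

    post = bp.dyn.LifRef(4000, V_initializer=bp.init.Normal(-55., 2.))
    syn = bp.dyn.Expon.desc(size=4000, tau=5.)
    out = bp.dyn.COBA.desc(E=0.)

    p1 = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
                                 syn=syn, out=out, post=post)
    p2 = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=0.6),
                                 syn=syn, out=out, post=post)

    # both projections share one before-update entry on `post`
    # (assuming nothing else was registered there)
    assert len(post.before_updates) == 1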
+class ProjAlignPostMg2(Projection):
+ """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
+
+ **Code Examples**
+
+ To define an E/I balanced network model.
+
+ .. code-block:: python
+
+ import brainpy as bp
+ import brainpy.math as bm
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.ProjAlignPostMg2(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.ProjAlignPostMg2(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=ni, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.ProjAlignPostMg2(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=ne, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.ProjAlignPostMg2(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, AutoDelaySupp],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
+ out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
+ post: Dynamic,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
+ check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
+ check.is_instance(post, Dynamic)
+ self.pre = pre
+ self.post = post
+ self.comm = comm
+
+ # delay initialization
+ if _pre_delay_repr not in self.pre.after_updates:
+ # pre should support "ProjAutoDelay"
+ delay_cls = _init_delay(pre.return_info())
+ # add to "after_updates"
+ self.pre.after_updates[_pre_delay_repr] = delay_cls
+ delay_cls: Delay = pre.after_updates[_pre_delay_repr]
+ delay_cls.register_entry(self.name, delay)
+
+ # synapse and output initialization
+ self._post_repr = f'{syn._identifier} // {out._identifier}'
+ if self._post_repr not in self.post.before_updates:
+ syn_cls = syn()
+ out_cls = out()
+ self.post.cur_inputs[self.name] = out_cls
+ self.post.before_updates[self._post_repr] = _AlignPost(syn_cls, out_cls)
+
+ def update(self):
+ x = self.pre.after_updates[_pre_delay_repr].at(self.name)
+ current = self.comm(x)
+ self.post.before_updates[self._post_repr].syn.add_current(current) # synapse post current
+ return current
+
+
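On the pre-synaptic side, the shared key `_pre_delay_repr` means that all projections leaving the same pre group reuse one delay container; each projection only registers its own entry with its own delay time. A rough sketch of that bookkeeping, reusing the EINet defined in the docstring above (attribute access here is for illustration, not public API):

.. code-block:: python

    net = EINet()   # the ProjAlignPostMg2 example from the docstring above

    # E2E and E2I both read delayed spikes from a single delay attached to net.E
    shared_delay = net.E.after_updates['_*_align_pre_spk_delay_*_']
    spk_for_e2e = shared_delay.at(net.E2E.name)   # entry registered with delay=0.1
    spk_for_e2i = shared_delay.at(net.E2I.name)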
+class ProjAlignPost1(Projection):
+ """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
+
+ To simulate an E/I balanced network:
+
+ .. code-block::
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon(size=4000, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.N)
+ self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon(size=4000, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.N)
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ comm: DynamicalSystem,
+ syn: JointType[DynamicalSystem, AlignPost],
+ out: JointType[DynamicalSystem, BindCondData],
+ post: Dynamic,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, Dynamic)
+ self.post = post
+ self.comm = comm
+
+ # synapse and output initialization
+ self.post.cur_inputs[self.name] = out
+ self.post.before_updates[self.name] = _AlignPost(syn, out)
+
+ def update(self, x):
+ current = self.comm(x)
+ syn: _AlignPost = self.post.before_updates[self.name].syn
+ syn.add_current(current) # synapse post current
+ return current
+
+
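The difference between the `Mg` and non-`Mg` variants is only in what they accept: `ProjAlignPostMg1`/`Mg2` take parameter descriptors (so identical projections can be merged, as sketched above), while `ProjAlignPost1`/`Post2` take concrete instances. A side-by-side sketch based on the docstring examples:

.. code-block:: python

    import brainpy as bp

    # descriptors: delayed construction, merged across projections (Mg variants)
    syn_desc = bp.dyn.Expon.desc(size=4000, tau=5.)
    out_desc = bp.dyn.COBA.desc(E=0.)

    # concrete instances: one synapse/output per projection (non-Mg variants)
    syn_inst = bp.dyn.Expon(size=4000, tau=5.)
    out_inst = bp.dyn.COBA(E=0.)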
+class ProjAlignPost2(Projection):
+ """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
+
+  To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.ProjAlignPost2(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon(size=ne, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.ProjAlignPost2(pre=self.E,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon(size=ni, tau=5.),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.ProjAlignPost2(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon(size=ne, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.ProjAlignPost2(pre=self.I,
+ delay=0.1,
+ comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon(size=ni, tau=10.),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
+ Args:
+ pre: The pre-synaptic neuron group.
+ delay: The synaptic delay.
+ comm: The synaptic communication.
+ syn: The synaptic dynamics.
+ out: The synaptic output.
+ post: The post-synaptic neuron group.
+ name: str. The projection name.
+ mode: Mode. The computing mode.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, AutoDelaySupp],
+ delay: Union[None, int, float],
+ comm: DynamicalSystem,
+ syn: JointType[DynamicalSystem, AlignPost],
+ out: JointType[DynamicalSystem, BindCondData],
+ post: Dynamic,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(comm, DynamicalSystem)
+ check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, Dynamic)
+ self.pre = pre
+ self.post = post
+ self.comm = comm
+
+ # delay initialization
+ if _pre_delay_repr not in self.pre.after_updates:
+ # pre should support "ProjAutoDelay"
+ delay_cls = _init_delay(pre.return_info())
+ # add to "after_updates"
+ self.pre.after_updates[_pre_delay_repr] = delay_cls
+ delay_cls: Delay = pre.after_updates[_pre_delay_repr]
+ delay_cls.register_entry(self.name, delay)
+
+ # synapse and output initialization
+ self.post.cur_inputs[self.name] = out
+ self.post.before_updates[self.name] = _AlignPost(syn, out)
+
+ def update(self):
+ x = self.pre.after_updates[_pre_delay_repr].at(self.name)
+ current = self.comm(x)
+ self.post.before_updates[self.name].syn.add_current(current) # synapse post current
+ return current
+
+
+class ProjAlignPreMg1(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
+
+  To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ delay=0.1,
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
+
Args:
pre: The pre-synaptic neuron group.
syn: The synaptic dynamics.
@@ -88,11 +625,11 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- assert isinstance(pre, DynamicalSystem)
- assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
- assert isinstance(comm, Callable)
- assert isinstance(out, JointType[DynamicalSystem, BindCondData])
- assert isinstance(post, Dynamic)
+ check.is_instance(pre, DynamicalSystem)
+ check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
+ check.is_instance(comm, Callable)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, Dynamic)
self.pre = pre
self.post = post
self.comm = comm
@@ -119,27 +656,79 @@ def update(self, x=None):
return current
-class ProjAlignPost(Projection):
- """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
+class ProjAlignPreMg2(Projection):
+ """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
+
+ To simulate an E/I balanced network model:
+
+ .. code-block:: python
+
+ class EINet(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ ne, ni = 3200, 800
+ self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.))
+ self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.E)
+ self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ne, tau=5.),
+ comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
+ out=bp.dyn.COBA(E=0.),
+ post=self.I)
+ self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.E)
+ self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I,
+ delay=0.1,
+ syn=bp.dyn.Expon.desc(size=ni, tau=10.),
+ comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
+ out=bp.dyn.COBA(E=-80.),
+ post=self.I)
+
+ def update(self, inp):
+ self.E2E()
+ self.E2I()
+ self.I2E()
+ self.I2I()
+ self.E(inp)
+ self.I(inp)
+ return self.E.spike
+
+ model = EINet()
+ indices = bm.arange(1000)
+ spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
+ bp.visualize.raster_plot(indices, spks, show=True)
+
Args:
pre: The pre-synaptic neuron group.
delay: The synaptic delay.
- comm: The synaptic communication.
syn: The synaptic dynamics.
+ comm: The synaptic communication.
out: The synaptic output.
post: The post-synaptic neuron group.
name: str. The projection name.
- mode: Mode. The computing mode.
+ mode: Mode. The computing mode.
"""
def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
delay: Union[None, int, float],
+ syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]],
comm: Callable,
- syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
- out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
+ out: JointType[DynamicalSystem, BindCondData],
post: Dynamic,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
@@ -147,36 +736,37 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- assert isinstance(pre, JointType[DynamicalSystem, AutoDelaySupp])
- assert isinstance(comm, Callable)
- assert isinstance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
- assert isinstance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
- assert isinstance(post, Dynamic)
+ check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
+ check.is_instance(comm, Callable)
+ check.is_instance(out, JointType[DynamicalSystem, BindCondData])
+ check.is_instance(post, Dynamic)
self.pre = pre
self.post = post
self.comm = comm
- # delay initialization
- self._delay_repr = '_*_align_pre_spk_delay_*_'
- if self._delay_repr not in self.pre.after_updates:
- # pre should support "ProjAutoDelay"
+ # synapse and delay initialization
+ if _pre_delay_repr not in self.pre.after_updates:
delay_cls = _init_delay(pre.return_info())
- # add to "after_updates"
- self.pre.after_updates[self._delay_repr] = delay_cls
- delay_cls: Delay = pre.after_updates[self._delay_repr]
- delay_cls.register_entry(self.name, delay)
+ self.pre.after_updates[_pre_delay_repr] = delay_cls
- # synapse and output initialization
- self._post_repr = f'{syn._identifier} // {out._identifier}'
- if self._post_repr not in self.post.before_updates:
+ # synapse
+ self._syn_id = f'{str(delay)} / {syn.identifier}'
+ if self._syn_id not in post.before_updates:
+ # delay
+ delay_cls: Delay = pre.after_updates[_pre_delay_repr]
+ delay_access = DelayAccess(delay_cls, delay)
+ # synapse
syn_cls = syn()
- out_cls = out()
- self.post.cur_inputs[self.name] = out_cls
- self.post.before_updates[self._post_repr] = _AlignPost(syn_cls, out_cls)
+ # add to "after_updates"
+ post.before_updates[self._syn_id] = _AlignPreMg(delay_access, syn_cls)
- def update(self, x=None):
- if x is None:
- x = self.pre.after_updates[self._delay_repr].at(self.name)
+ # output initialization
+ post.cur_inputs[self.name] = out
+
+ def update(self):
+ x = self.post.before_updates[self._syn_id].syn.return_info()
+ x = _get_return(x)
current = self.comm(x)
- self.post.before_updates[self._post_repr].syn.add_current(current) # synapse post current
+ self.post.cur_inputs[self.name].bind_cond(current)
return current
diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py
index a09617988..0ec6b26ad 100644
--- a/brainpy/dyn/projections.py
+++ b/brainpy/dyn/projections.py
@@ -1,8 +1,13 @@
from brainpy._src.dyn.projections.aligns import (
- ProjAlignPost as ProjAlignPost,
- ProjAlignPre as ProjAlignPre,
+ VanillaProj,
+ ProjAlignPostMg1,
+ ProjAlignPostMg2,
+ ProjAlignPost1,
+ ProjAlignPost2,
+ ProjAlignPreMg1,
+ ProjAlignPreMg2,
)
from brainpy._src.dyn.projections.conn import (
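For downstream code that imported the removed names, the closest replacements (as used in the `dynold` refactor in the following patch) are `ProjAlignPreMg1` for `ProjAlignPre` and `ProjAlignPost2` for `ProjAlignPost`. A hedged migration sketch:

.. code-block:: python

    import brainpy as bp

    pre = bp.dyn.LifRef(3200, V_initializer=bp.init.Normal(-55., 2.))
    post = bp.dyn.LifRef(4000, V_initializer=bp.init.Normal(-55., 2.))

    # formerly: bp.dyn.ProjAlignPost(pre, delay, comm, syn=..., out=..., post=...)
    proj = bp.dyn.ProjAlignPost2(pre=pre,
                                 delay=0.1,
                                 comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
                                 syn=bp.dyn.Expon(size=4000, tau=5.),
                                 out=bp.dyn.COBA(E=0.),
                                 post=post)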
From 5605ae85f3e928914c9e2973b93637c90470c05c Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 19 Jul 2023 22:51:38 +0800
Subject: [PATCH 049/326] reformat old version synapses
---
.../_src/dynold/synapses/abstract_models.py | 114 ++++++++++-----
brainpy/_src/dynold/synapses/base.py | 131 ++++++++++--------
.../dynold/synplast/short_term_plasticity.py | 9 +-
3 files changed, 158 insertions(+), 96 deletions(-)
diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py
index bc50f8c4c..114b74468 100644
--- a/brainpy/_src/dynold/synapses/abstract_models.py
+++ b/brainpy/_src/dynold/synapses/abstract_models.py
@@ -6,13 +6,14 @@
import brainpy.math as bm
from brainpy._src.connect import TwoEndConnector, All2All, One2One
+from brainpy._src.context import share
from brainpy._src.dyn import synapses
-from brainpy._src.dynold.synouts import MgBlock, CUBA
from brainpy._src.dyn.base import NeuDyn
-from brainpy._src.initialize import Initializer
-from brainpy._src.mixin import AlignPost
+from brainpy._src.dynold.synouts import MgBlock, CUBA
+from brainpy._src.initialize import Initializer, variable_
+from brainpy._src.integrators.ode.generic import odeint
from brainpy.types import ArrayType
-from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre, _TwoEndConnAlignPost, _DelayedSyn, _init_stp
+from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre, _DelayedSyn, _init_stp
__all__ = [
'Delta',
@@ -175,7 +176,7 @@ def update(self, pre_spike=None):
return self.output(post_vs)
-class Exponential(_TwoEndConnAlignPost, AlignPost):
+class Exponential(TwoEndConn):
r"""Exponential decay synapse model.
**Model Descriptions**
@@ -201,10 +202,10 @@ class Exponential(_TwoEndConnAlignPost, AlignPost):
& g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\
& \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}).
\end{aligned}
-
+
where :math:`\mathrm{STP}` is used to model the short-term plasticity effect.
-
-
+
+
**Model Examples**
- `(Brunel & Hakim, 1999) Fast Global Oscillation `_
@@ -241,9 +242,9 @@ class Exponential(_TwoEndConnAlignPost, AlignPost):
Parameters
----------
- pre: NeuDyn
+ pre: NeuGroup
The pre-synaptic neuron group.
- post: NeuDyn
+ post: NeuGroup
The post-synaptic neuron group.
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
@@ -282,10 +283,19 @@ def __init__(
delay_step: Union[int, ArrayType, Initializer, Callable] = None,
tau: Union[float, ArrayType] = 8.0,
method: str = 'exp_auto',
- name: Optional[str] = None,
- mode: Optional[bm.Mode] = None,
+
+ # other parameters
+ name: str = None,
+ mode: bm.Mode = None,
stop_spike_gradient: bool = False,
):
+ super().__init__(pre=pre,
+ post=post,
+ conn=conn,
+ output=output,
+ stp=stp,
+ name=name,
+ mode=mode)
# parameters
self.stop_spike_gradient = stop_spike_gradient
self.comp_method = comp_method
@@ -293,37 +303,71 @@ def __init__(
if bm.size(self.tau) != 1:
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
- syn = synapses.Expon.desc(post.size,
- post.keep_size,
- mode=mode,
- tau=tau,
- method=method)
+ # connections and weights
+ self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')
- super().__init__(pre=pre,
- post=post,
- syn=syn,
- conn=conn,
- output=output,
- stp=stp,
- comp_method=comp_method,
- g_max=g_max,
- delay_step=delay_step,
- name=name,
- mode=mode)
+ # variables
+ self.g = variable_(bm.zeros, self.post.num, self.mode)
+ self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
- # copy the references
- syn = self.post.before_updates[self.proj._post_repr].syn
- self.g = syn.g
+ # function
+ self.integral = odeint(lambda g, t: -g / self.tau, method=method)
+
+ def reset_state(self, batch_size=None):
+ self.g.value = variable_(bm.zeros, self.post.num, batch_size)
+ self.output.reset_state(batch_size)
+ if self.stp is not None: self.stp.reset_state(batch_size)
def update(self, pre_spike=None):
- return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient)
+ t, dt = share['t'], share['dt']
+
+ # delays
+ if pre_spike is None:
+ pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
+ pre_spike = bm.as_jax(pre_spike)
+ if self.stop_spike_gradient:
+ pre_spike = jax.lax.stop_gradient(pre_spike)
+
+ # update sub-components
+ self.output.update()
+ if self.stp is not None:
+ self.stp.update(pre_spike)
- def add_current(self, input):
- self.g += input
+ # post values
+ if isinstance(self.conn, All2All):
+ syn_value = bm.asarray(pre_spike, dtype=bm.float_)
+ if self.stp is not None: syn_value = self.stp(syn_value)
+ post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
+ elif isinstance(self.conn, One2One):
+ syn_value = bm.asarray(pre_spike, dtype=bm.float_)
+ if self.stp is not None: syn_value = self.stp(syn_value)
+ post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
+ else:
+ if self.comp_method == 'sparse':
+ f = lambda s: bm.event.csrmv(self.g_max,
+ self.conn_mask[0],
+ self.conn_mask[1],
+ s,
+ shape=(self.pre.num, self.post.num),
+ transpose=True)
+ if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f)
+ post_vs = f(pre_spike)
+ # if not isinstance(self.stp, _NullSynSTP):
+ # raise NotImplementedError()
+ else:
+ syn_value = bm.asarray(pre_spike, dtype=bm.float_)
+ if self.stp is not None:
+ syn_value = self.stp(syn_value)
+ post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
+ # updates
+ self.g.value = self.integral(self.g.value, t, dt) + post_vs
+
+ # output
+ return self.output(self.g)
class _DelayedDualExp(_DelayedSyn):
- not_desc_params = ('master', 'stp', 'mode')
+ not_desc_params = ('master', 'mode')
def __init__(self, size, keep_size, mode, tau_decay, tau_rise, method, master, stp=None):
syn = synapses.DualExpon(size,
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
index ac84ed797..f3fcda4c3 100644
--- a/brainpy/_src/dynold/synapses/base.py
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -7,6 +7,7 @@
from brainpy._src.dnn import linear
from brainpy._src.dyn import projections
from brainpy._src.dyn.base import NeuDyn
+from brainpy._src.dyn.projections.aligns import _pre_delay_repr
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, ParamDescInit, JointType,
@@ -24,7 +25,6 @@
]
-
class _SynapseComponent(DynamicalSystem):
"""Base class for modeling synaptic components,
including synaptic output, synaptic short-term plasticity,
@@ -119,7 +119,7 @@ def update(self, pre_spike):
def return_info(self):
assert self.isregistered
- return ReturnInfo(self.master.pre.varshape, None, self.master.pre.mode, init=bm.zeros)
+ return ReturnInfo(self.master.pre.varshape, None, self.master.pre.mode, bm.zeros)
class _NullSynOut(_SynOut):
@@ -316,38 +316,38 @@ def __init__(
# Projection
if isinstance(conn, All2All):
- proj = projections.ProjAlignPre(pre=pre,
- syn=syn,
- delay=delay,
- comm=linear.AllToAll(pre.num, post.num, g_max),
- out=_TempOut(),
- post=post)
+ proj = projections.ProjAlignPreMg1(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.AllToAll(pre.num, post.num, g_max),
+ out=_TempOut(),
+ post=post)
elif isinstance(conn, One2One):
assert post.num == pre.num
- proj = projections.ProjAlignPre(pre=pre,
- syn=syn,
- delay=delay,
- comm=linear.OneToOne(pre.num, g_max),
- out=_TempOut(),
- post=post)
+ proj = projections.ProjAlignPreMg1(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.OneToOne(pre.num, g_max),
+ out=_TempOut(),
+ post=post)
else:
if comp_method == 'dense':
- proj = projections.ProjAlignPre(pre=pre,
- syn=syn,
- delay=delay,
- comm=linear.MaskedLinear(conn, g_max),
- out=_TempOut(),
- post=post)
+ proj = projections.ProjAlignPreMg1(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.MaskedLinear(conn, g_max),
+ out=_TempOut(),
+ post=post)
elif comp_method == 'sparse':
- proj = projections.ProjAlignPre(pre=pre,
- syn=syn,
- delay=delay,
- comm=linear.CSRLinear(conn, g_max),
- out=_TempOut(),
- post=post)
+ proj = projections.ProjAlignPreMg1(pre=pre,
+ syn=syn,
+ delay=delay,
+ comm=linear.CSRLinear(conn, g_max),
+ out=_TempOut(),
+ post=post)
else:
raise UnsupportedError(f'Does not support {comp_method}, only "sparse" or "dense".')
@@ -365,12 +365,22 @@ def update(self, pre_spike=None, stop_spike_gradient: bool = False):
return self.output(current)
+class _UpdateSTP(DynamicalSystem):
+ def __init__(self, stp):
+ super().__init__()
+ self.stp = stp
+
+ def update(self, x):
+ self.stp.update(x)
+ return self.stp(x)
+
+
class _TwoEndConnAlignPost(TwoEndConn):
def __init__(
self,
pre: NeuDyn,
post: NeuDyn,
- syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
+ syn: JointType[DynamicalSystem, AlignPost],
conn: TwoEndConnector,
g_max: Union[float, ArrayType, Callable],
output: _SynOut = _NullSynOut(),
@@ -389,50 +399,52 @@ def __init__(
mode=mode,
init_stp=True)
- pre = _DelayedSyn(pre, self.stp)
delay = _get_delay(delay_step)
-
- # make every synapse unique
- syn._identifier = syn._identifier + f' // {self.name}'
+ if self.stp is None:
+ pre = pre
+ else:
+ stp = _UpdateSTP(self.stp)
+ pre.after_updates[self.name] = stp
+ pre = stp
# Projection
if isinstance(conn, All2All):
- proj = projections.ProjAlignPost(pre=pre,
- delay=delay,
- comm=linear.AllToAll(self.pre.num, self.post.num, g_max),
- syn=syn,
- out=_TempOut.desc(),
- post=post)
+ proj = projections.ProjAlignPost2(pre=pre,
+ delay=delay,
+ comm=linear.AllToAll(self.pre.num, self.post.num, g_max),
+ syn=syn,
+ out=_TempOut(),
+ post=post)
elif isinstance(conn, One2One):
assert post.num == self.pre.num
- proj = projections.ProjAlignPost(pre=pre,
- delay=delay,
- comm=linear.OneToOne(self.pre.num, g_max),
- syn=syn,
- out=_TempOut.desc(),
- post=post)
+ proj = projections.ProjAlignPost2(pre=pre,
+ delay=delay,
+ comm=linear.OneToOne(self.pre.num, g_max),
+ syn=syn,
+ out=_TempOut(),
+ post=post)
else:
if comp_method == 'dense':
- proj = projections.ProjAlignPost(pre=pre,
- delay=delay,
- comm=linear.MaskedLinear(conn, g_max),
- syn=syn,
- out=_TempOut.desc(),
- post=post)
+ proj = projections.ProjAlignPost2(pre=pre,
+ delay=delay,
+ comm=linear.MaskedLinear(self.conn, g_max),
+ syn=syn,
+ out=_TempOut(),
+ post=post)
elif comp_method == 'sparse':
if self.stp is None:
- comm = linear.EventCSRLinear(conn, g_max)
+ comm = linear.EventCSRLinear(self.conn, g_max)
else:
- comm = linear.CSRLinear(conn, g_max)
- proj = projections.ProjAlignPost(pre=pre,
- delay=delay,
- comm=comm,
- syn=syn,
- out=_TempOut.desc(),
- post=post)
+ comm = linear.CSRLinear(self.conn, g_max)
+ proj = projections.ProjAlignPost2(pre=pre,
+ delay=delay,
+ comm=comm,
+ syn=syn,
+ out=_TempOut(),
+ post=post)
else:
raise UnsupportedError(f'Does not support {comp_method}, only "sparse" or "dense".')
@@ -441,12 +453,12 @@ def __init__(
def update(self, pre_spike=None, stop_spike_gradient: bool = False):
if pre_spike is None:
- pre_spike = self.proj.pre.after_updates[self.proj._delay_repr].at(self.proj.name)
+ pre_spike = self.proj.pre.after_updates[_pre_delay_repr].at(self.proj.name)
if stop_spike_gradient:
# TODO: if self.stp is not None
pre_spike = jax.lax.stop_gradient(pre_spike)
current = self.proj.comm(pre_spike)
- self.proj.post.before_updates[self.proj._post_repr].syn.add_current(current) # synapse post current
+ self.proj.post.before_updates[self.proj.name].syn.add_current(current) # synapse post current
return self.output(current)
@@ -468,4 +480,3 @@ def return_info(self):
return self.syn.return_info()
else:
return self.stp.return_info()
-
diff --git a/brainpy/_src/dynold/synplast/short_term_plasticity.py b/brainpy/_src/dynold/synplast/short_term_plasticity.py
index da3428662..b19825e64 100644
--- a/brainpy/_src/dynold/synplast/short_term_plasticity.py
+++ b/brainpy/_src/dynold/synplast/short_term_plasticity.py
@@ -58,7 +58,7 @@ def __init__(
method: str = 'exp_auto',
name: str = None
):
- super(STD, self).__init__(name=name)
+ super().__init__(name=name)
# parameters
is_float(tau, 'tau', min_bound=0, )
@@ -89,6 +89,9 @@ def filter(self, g):
raise ValueError('Shape does not match.')
return g * self.x
+ def __repr__(self):
+ return f'{self.__class__.__name__}(tau={self.tau}, U={self.U}, method={self.method})'
+
class STP(_SynSTP):
r"""Synaptic output with short-term plasticity.
@@ -184,3 +187,7 @@ def filter(self, g):
if jnp.shape(g) != self.x.shape:
raise ValueError('Shape does not match.')
return g * self.x * self.u
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(tau_f={self.tau_f}, tau_d={self.tau_d}, U={self.U}, method={self.method})'
+
From 1c3d8014007b4627f952ea879df76d99c4f95b58 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 19 Jul 2023 22:53:14 +0800
Subject: [PATCH 050/326] add `brainpy.dyn.MixIons` for modeling cross ion
channels
---
brainpy/_src/dyn/channels/potassium_compatible.py | 15 ++++++++-------
brainpy/_src/dyn/channels/sodium_compatible.py | 12 ++++++------
brainpy/_src/dyn/ions/base.py | 8 +++++---
brainpy/_src/dyn/ions/calcium.py | 4 ++--
brainpy/_src/dyn/ions/tests/test_MixIons.py | 6 +++---
brainpy/_src/dyn/neurons/hh.py | 1 -
6 files changed, 24 insertions(+), 22 deletions(-)
diff --git a/brainpy/_src/dyn/channels/potassium_compatible.py b/brainpy/_src/dyn/channels/potassium_compatible.py
index d9bb41b61..2bb4468ed 100644
--- a/brainpy/_src/dyn/channels/potassium_compatible.py
+++ b/brainpy/_src/dyn/channels/potassium_compatible.py
@@ -9,12 +9,11 @@
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dyn.channels.leaky import LeakyChannel
+from brainpy._src.dyn.channels.base import IonChannel
from brainpy._src.dyn.neurons.hh import HHTypedNeuron
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import ArrayType
-from .potassium import PotassiumChannel
__all__ = [
'IKDR_Ba2002',
@@ -29,7 +28,7 @@
]
-class _IK_p4_markov(PotassiumChannel):
+class _IK_p4_markov(IonChannel):
r"""The delayed rectifier potassium channel of :math:`p^4`
current which described with first-order Markov chain.
@@ -339,7 +338,7 @@ def f_p_beta(self, V):
return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80)
-class _IKA_p4q_ss(PotassiumChannel):
+class _IKA_p4q_ss(IonChannel):
r"""The rapidly inactivating Potassium channel of :math:`p^4q`
current which described with steady-state format.
@@ -634,7 +633,7 @@ def f_q_tau(self, V):
19.)
-class _IKK2_pq_ss(PotassiumChannel):
+class _IKK2_pq_ss(IonChannel):
r"""The slowly inactivating Potassium channel of :math:`pq`
current which described with steady-state format.
@@ -921,7 +920,7 @@ def f_q_tau(self, V):
8.9)
-class IKNI_Ya1989(PotassiumChannel):
+class IKNI_Ya1989(IonChannel):
r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_.
This slow potassium current can effectively account for spike-frequency adaptation.
@@ -1019,7 +1018,7 @@ def f_p_tau(self, V):
return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.))
-class IKL(LeakyChannel):
+class IKL(IonChannel):
"""The potassium leak channel current.
Parameters
@@ -1031,6 +1030,8 @@ class IKL(LeakyChannel):
The reversal potential.
"""
+ master_type = HHTypedNeuron
+
def __init__(
self,
size: Union[int, Sequence[int]],
diff --git a/brainpy/_src/dyn/channels/sodium_compatible.py b/brainpy/_src/dyn/channels/sodium_compatible.py
index 9a05593b0..ec60eb1c9 100644
--- a/brainpy/_src/dyn/channels/sodium_compatible.py
+++ b/brainpy/_src/dyn/channels/sodium_compatible.py
@@ -13,7 +13,7 @@
from brainpy._src.initialize import Initializer, parameter, variable
from brainpy._src.integrators import odeint, JointEq
from brainpy.types import ArrayType
-from .sodium import SodiumChannel
+from .base import IonChannel
__all__ = [
'INa_Ba2002',
@@ -22,7 +22,7 @@
]
-class _INa_p3q_markov(SodiumChannel):
+class _INa_p3q_markov(IonChannel):
r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain.
The general model can be used to model the dynamics with:
@@ -64,7 +64,7 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(_INa_p3q_markov, self).__init__(size=size,
+ super().__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode)
@@ -173,7 +173,7 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(INa_Ba2002, self).__init__(size,
+ super().__init__(size,
keep_size=keep_size,
name=name,
method=method,
@@ -260,7 +260,7 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(INa_TM1991, self).__init__(size,
+ super().__init__(size,
keep_size=keep_size,
name=name,
method=method,
@@ -347,7 +347,7 @@ def __init__(
name: str = None,
mode: bm.Mode = None,
):
- super(INa_HH1952, self).__init__(size,
+ super().__init__(size,
keep_size=keep_size,
name=name,
method=method,
diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py
index 804e551bc..175b9413e 100644
--- a/brainpy/_src/dyn/ions/base.py
+++ b/brainpy/_src/dyn/ions/base.py
@@ -166,13 +166,14 @@ def update(self, V):
for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values():
node.update(V, self.C, self.E)
- def current(self, V, C=None, E=None):
+ def current(self, V, C=None, E=None, external: bool = False):
"""Generate ion channel current.
Args:
V: The membrane potential.
C: The ion concentration.
E: The reversal potential.
+      external: Whether to include the currents of externally registered channels.
Returns:
Current.
@@ -186,8 +187,9 @@ def current(self, V, C=None, E=None):
if len(nodes) > 0:
for node in nodes:
current = current + node.current(V, C, E)
- for key, node in self.external.items():
- current = current + node(V, C, E)
+ if external:
+ for key, node in self.external.items():
+ current = current + node(V, C, E)
return current
def reset_state(self, V, batch_size=None):
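With the new `external` flag, `current` returns only the currents of the ion's own registered channels by default; externally attached channels (e.g. via `MixIons`) are added only on request. Illustrative calls, reusing the `hh` model built in the test file shown below (not a standalone script):

.. code-block:: python

    I_channels = hh.ca.current(hh.V)                  # registered Ca2+ channels only
    I_total    = hh.ca.current(hh.V, external=True)   # plus externally registered channels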
diff --git a/brainpy/_src/dyn/ions/calcium.py b/brainpy/_src/dyn/ions/calcium.py
index 4fa50daed..49e8fa18c 100644
--- a/brainpy/_src/dyn/ions/calcium.py
+++ b/brainpy/_src/dyn/ions/calcium.py
@@ -273,7 +273,7 @@ def __init__(
self.C_rest = parameter(C_rest, self.varshape, allow_none=False)
def derivative(self, C, t, V):
- ICa = self.current(V, C, self.E)
+ ICa = self.current(V, C, self.E, external=True)
drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.)
return drive + (self.C_rest - C) / self.tau
@@ -316,6 +316,6 @@ def __init__(
self.beta = parameter(beta, self.varshape, allow_none=False)
def derivative(self, C, t, V):
- ICa = self.current(V, C, self.E)
+ ICa = self.current(V, C, self.E, external=True)
drive = bm.maximum(- self.alpha * ICa, 0.)
return drive - self.beta * C
diff --git a/brainpy/_src/dyn/ions/tests/test_MixIons.py b/brainpy/_src/dyn/ions/tests/test_MixIons.py
index b2731968e..e196ca4d4 100644
--- a/brainpy/_src/dyn/ions/tests/test_MixIons.py
+++ b/brainpy/_src/dyn/ions/tests/test_MixIons.py
@@ -85,9 +85,9 @@ def __init__(self, size):
hh.reset_state()
- ICa = hh.ca.current(hh.V)
- INa = hh.na.current(hh.V)
- IK = hh.k.current(hh.V)
+ ICa = hh.ca.current(hh.V, external=True)
+ INa = hh.na.current(hh.V, external=True)
+ IK = hh.k.current(hh.V, external=True)
print(ICa, INa, IK)
self.assertTrue(bm.allclose(INa, 0.))
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index 4f6e68d34..8440766f3 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -4,7 +4,6 @@
import brainpy.math as bm
from brainpy._src.context import share
-from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.dyn.base import NeuDyn, IonChaDyn
from brainpy._src.initialize import OneInit
from brainpy._src.initialize import Uniform, variable_, noise as init_noise
From af476ba8bc2a30c5e4848859d7f9d4c88e00e455 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 19 Jul 2023 22:54:13 +0800
Subject: [PATCH 051/326] add `step_run` for convenient simulation of any
`DynamicalSystem`
---
brainpy/_src/dynsys.py | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 861b679a0..02624815a 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -159,6 +159,15 @@ def clear_input(self):
"""Clear the input at the current time step."""
pass
+ def step_run(self, i, *args, **kwargs):
+ global share
+ if share is None:
+ from brainpy._src.context import share
+ share.save(i=i, t=i * bm.dt)
+ return self.update(*args, **kwargs)
+
+ jit_step_run = bm.cls_jit(step_run, inline=True)
+
@property
def mode(self) -> bm.Mode:
"""Mode of the model, which is useful to control the multiple behaviors of the model."""
From d9a737b8eb24b926024cab153c7136917b5727a0 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 19 Jul 2023 22:54:35 +0800
Subject: [PATCH 052/326] update brainpy package
---
brainpy/__init__.py | 4 +-
brainpy/_src/delay.py | 49 +++++--
brainpy/_src/dnn/linear.py | 29 ++--
brainpy/_src/dyn/others/input.py | 40 +++---
brainpy/_src/dyn/synapses/abstract_models.py | 24 ++--
brainpy/_src/math/object_transform/jit.py | 5 +-
.../_src/math/object_transform/variables.py | 2 +-
brainpy/_src/mixin.py | 56 ++++++--
brainpy/check.py | 12 +-
brainpy/dyn/__init__.py | 1 +
brainpy/dyn/compat.py | 10 ++
docs/conf.py | 6 +-
examples/dynamics_simulation/COBA-v2.py | 28 ++--
examples/dynamics_simulation/COBA.py | 129 ------------------
14 files changed, 163 insertions(+), 232 deletions(-)
create mode 100644 brainpy/dyn/compat.py
delete mode 100644 examples/dynamics_simulation/COBA.py
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 4b2f24822..77302e150 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -58,14 +58,14 @@
DynamicalSystem as DynamicalSystem,
DynSysGroup as DynSysGroup, # collectors
Sequential as Sequential,
- Network as Network,
Dynamic as Dynamic, # category
Projection as Projection,
)
DynamicalSystemNS = DynamicalSystem
+Network = DynSysGroup
# delays
from brainpy._src.delay import (
- VariableDelay as VariableDelay,
+ VariDelay as VariDelay,
)
# building blocks
diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py
index d24248d8c..bac40e53f 100644
--- a/brainpy/_src/delay.py
+++ b/brainpy/_src/delay.py
@@ -21,8 +21,9 @@
__all__ = [
'Delay',
- 'VariableDelay',
+ 'VariDelay',
'DataDelay',
+ 'DelayAccess',
]
@@ -431,8 +432,8 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode):
return sharding
-class VariableDelay(Delay):
- """Delay variable which has a fixed delay length.
+class VariDelay(Delay):
+ """Generate Delays for the given :py:class:`~.Variable` instance.
The data in this delay variable is arranged as::
@@ -517,8 +518,8 @@ def __init__(
# other info
if entries is not None:
- for entry, value in entries.items():
- self.register_entry(entry, value)
+ for entry, delay_time in entries.items():
+ self.register_entry(entry, delay_time)
def register_entry(
self,
@@ -572,11 +573,17 @@ def at(self, entry: str, *indices) -> bm.Array:
raise KeyError(f'Does not find delay entry "{entry}".')
delay_step = self._registered_entries[entry]
if delay_step is None or delay_step == 0.:
- return self.target.value
+ if len(indices):
+ return self.target[indices]
+ else:
+ return self.target.value
else:
assert self.data is not None
if delay_step == 0:
- return self.target.value
+ if len(indices):
+ return self.target[indices]
+ else:
+ return self.target.value
else:
return self.retrieve(delay_step, *indices)
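The change above lets `at(entry, *indices)` index into the target even for zero-delay entries. A short sketch of the renamed `VariDelay` with per-entry delay times (argument handling beyond this hunk is assumed):

.. code-block:: python

    import brainpy as bp
    import brainpy.math as bm

    spike = bm.Variable(bm.zeros(100))
    delay = bp.VariDelay(spike, entries={'fast': None, 'slow': 2.0})   # delay times; None = no delay

    d_all  = delay.at('slow')                  # full delayed vector
    d_part = delay.at('fast', bm.arange(10))   # indices now also work for the no-delay entry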
@@ -683,16 +690,15 @@ def _init_data(self, length: int, batch_size: int = None):
self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)
-class DataDelay(VariableDelay):
-
+class DataDelay(VariDelay):
not_desc_params = ('time', 'entries')
def __init__(
self,
# delay target
- target: bm.Variable,
- target_init: Callable,
+ data: bm.Variable,
+ data_init: Union[Callable, bm.Array, jax.Array],
# delay time
time: Optional[Union[int, float]] = None,
@@ -710,8 +716,8 @@ def __init__(
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
- self.target_init = target_init
- super().__init__(target=target,
+ self.target_init = data_init
+ super().__init__(target=data,
time=time,
init=init,
entries=entries,
@@ -736,3 +742,20 @@ def update(
super().update(latest_value)
+class DelayAccess(DynamicalSystem):
+ def __init__(
+ self,
+ delay: Delay,
+ time: Union[None, int, float],
+ *indices
+ ):
+ super().__init__(mode=delay.mode)
+ self.delay = delay
+ assert isinstance(delay, Delay)
+ delay.register_entry(self.name, time)
+ self.indices = indices
+
+ def update(self):
+ return self.delay.at(self.name, *self.indices)
+
+
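`DelayAccess` wraps a registered entry so that a module can read a fixed slice of a delay inside its own `update`. A small sketch, importing from the private module because the public export path is not shown in this patch:

.. code-block:: python

    import brainpy.math as bm
    from brainpy._src.delay import VariDelay, DelayAccess   # public aliases not shown here

    spike = bm.Variable(bm.zeros(100))
    delay = VariDelay(spike, time=2.0)
    reader = DelayAccess(delay, 2.0, bm.arange(10))   # registers its own entry on `delay`

    delayed_slice = reader()   # equivalent to delay.at(reader.name, bm.arange(10))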
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index ef7cc377f..3bdc3a31c 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -4,6 +4,7 @@
from typing import Dict, Optional, Union, Callable
import jax
+import numpy as np
import jax.numpy as jnp
from brainpy import math as bm
@@ -63,8 +64,8 @@ def __init__(
num_out: int,
W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
- mode: bm.Mode = None,
- name: str = None,
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
):
super(Dense, self).__init__(mode=mode, name=name)
@@ -642,7 +643,7 @@ def __init__(
num_out: int,
prob: float,
weight: float,
- seed: int,
+ seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
@@ -654,7 +655,7 @@ def __init__(
self.prob = prob
self.sharding = sharding
self.transpose = transpose
- self.seed = seed
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
self.atomic = atomic
self.num_in = num_in
self.num_out = num_out
@@ -723,7 +724,7 @@ def __init__(
prob: float,
w_low: float,
w_high: float,
- seed: int,
+ seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
@@ -735,7 +736,7 @@ def __init__(
self.prob = prob
self.sharding = sharding
self.transpose = transpose
- self.seed = seed
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
self.atomic = atomic
self.num_in = num_in
self.num_out = num_out
@@ -803,7 +804,7 @@ def __init__(
prob: float,
w_mu: float,
w_sigma: float,
- seed: int,
+ seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
transpose: bool = False,
atomic: bool = False,
@@ -815,7 +816,7 @@ def __init__(
self.prob = prob
self.sharding = sharding
self.transpose = transpose
- self.seed = seed
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
self.atomic = atomic
self.num_in = num_in
self.num_out = num_out
@@ -881,7 +882,7 @@ def __init__(
num_out: int,
prob: float,
weight: float,
- seed: int,
+ seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
@@ -893,7 +894,7 @@ def __init__(
self.prob = prob
self.sharding = sharding
self.transpose = transpose
- self.seed = seed
+ self.seed = np.random.randint(0, 1000000) if seed is None else seed
self.atomic = atomic
self.num_in = num_in
self.num_out = num_out
@@ -962,7 +963,7 @@ def __init__(
prob: float,
w_low: float,
w_high: float,
- seed: int,
+ seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
@@ -974,7 +975,7 @@ def __init__(
self.prob = prob
self.sharding = sharding
self.transpose = transpose
- self.seed = seed
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
self.atomic = atomic
self.num_in = num_in
self.num_out = num_out
@@ -1042,7 +1043,7 @@ def __init__(
prob: float,
w_mu: float,
w_sigma: float,
- seed: int,
+ seed: Optional[int] = None,
sharding: Optional[Sharding] = None,
transpose: bool = False,
atomic: bool = False,
@@ -1054,7 +1055,7 @@ def __init__(
self.prob = prob
self.sharding = sharding
self.transpose = transpose
- self.seed = seed
+ self.seed = np.random.randint(0, 100000) if seed is None else seed
self.atomic = atomic
self.num_in = num_in
self.num_out = num_out
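After this change the JIT-connectivity layers no longer require an explicit `seed`; omitting it draws one from NumPy, while passing it keeps runs reproducible. For example:

.. code-block:: python

    import brainpy as bp

    layer_random = bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6)           # random seed
    layer_fixed  = bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6, seed=42)  # reproducible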
diff --git a/brainpy/_src/dyn/others/input.py b/brainpy/_src/dyn/others/input.py
index 0bf8a2b76..10ee8ab2c 100644
--- a/brainpy/_src/dyn/others/input.py
+++ b/brainpy/_src/dyn/others/input.py
@@ -40,11 +40,11 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
- super(InputGroup, self).__init__(name=name,
- sharding=sharding,
- size=size,
- keep_size=keep_size,
- mode=mode)
+ super().__init__(name=name,
+ sharding=sharding,
+ size=size,
+ keep_size=keep_size,
+ mode=mode)
def update(self, x):
return x
@@ -74,11 +74,11 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
):
- super(OutputGroup, self).__init__(name=name,
- sharding=sharding,
- size=size,
- keep_size=keep_size,
- mode=mode)
+ super().__init__(name=name,
+ sharding=sharding,
+ size=size,
+ keep_size=keep_size,
+ mode=mode)
def update(self, x):
return x
@@ -130,11 +130,11 @@ def __init__(
mode: Optional[bm.Mode] = None,
need_sort: bool = True,
):
- super(SpikeTimeGroup, self).__init__(size=size,
- sharding=sharding,
- name=name,
- keep_size=keep_size,
- mode=mode)
+ super().__init__(size=size,
+ sharding=sharding,
+ name=name,
+ keep_size=keep_size,
+ mode=mode)
# parameters
if keep_size:
@@ -202,11 +202,11 @@ def __init__(
mode: Optional[bm.Mode] = None,
seed=None,
):
- super(PoissonGroup, self).__init__(size=size,
- sharding=sharding,
- name=name,
- keep_size=keep_size,
- mode=mode)
+ super().__init__(size=size,
+ sharding=sharding,
+ name=name,
+ keep_size=keep_size,
+ mode=mode)
if seed is not None:
warnings.warn('')
diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index 81cf954d5..24b690951 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -334,7 +334,8 @@ def add_current(self, inp):
self.g_decay += inp
def return_info(self):
- return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros)
+ return ReturnInfo(self.varshape, self.sharding, self.mode,
+ lambda shape: self.coeff * (self.g_decay - self.g_rise))
DualExponV2.__doc__ = DualExponV2.__doc__ % (pneu_doc,)
@@ -677,22 +678,21 @@ def update(self, pre_spike):
t = share.load('t')
dt = share.load('dt')
u, x = self.integral(self.u.value, self.x.value, t, dt)
- if pre_spike.dtype == jax.numpy.bool_:
- u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
- x = bm.where(pre_spike, x - u * self.x, x)
- else:
- u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
- x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
+ # if pre_spike.dtype == jax.numpy.bool_:
+ # u = bm.where(pre_spike, u + self.U * (1 - self.u), u)
+ # x = bm.where(pre_spike, x - u * self.x, x)
+ # else:
+ # u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u
+ # x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x
+ u = pre_spike * self.U * (1 - self.u) + u
+ x = pre_spike * -u * self.x + x
self.x.value = x
self.u.value = u
return u * x
def return_info(self):
- return ReturnInfo(size=self.varshape,
- batch_or_mode=self.mode,
- axis_names=self.sharding,
- init=Constant(self.U))
+ return ReturnInfo(self.varshape, self.sharding, self.mode,
+ lambda shape: self.u * self.x)
STP.__doc__ = STP.__doc__ % (pneu_doc,)
-
diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index 42111dba0..93f9c0db8 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -405,13 +405,14 @@ def _make_jit_fun(
@wraps(fun)
def call_fun(self, *args, **kwargs):
- fun2 = partial(fun, self)
if jax.config.jax_disable_jit:
- return fun2(*args, **kwargs)
+ return fun(self, *args, **kwargs)
hash_v = hash(fun) + hash(self)
cache = get_stack_cache(hash_v) # TODO: better cache mechanism
if cache is None:
+ fun2 = partial(fun, self)
+
with jax.ensure_compile_time_eval():
if len(static_argnums) or len(static_argnames):
fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames)
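
A hedged sketch of the caching idea in the hunk above: the bound partial is now built only on a cache miss, and compilation is keyed by the (function, instance) pair. The plain dict below stands in for brainpy's internal `get_stack_cache` machinery:

    from functools import partial

    _cache = {}

    def call_cached(fun, obj, *args, **kwargs):
        key = (hash(fun), hash(obj))
        if key not in _cache:
            _cache[key] = partial(fun, obj)   # built lazily, only when not yet cached
        return _cache[key](*args, **kwargs)
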
diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py
index 7a10a8227..e461e691f 100644
--- a/brainpy/_src/math/object_transform/variables.py
+++ b/brainpy/_src/math/object_transform/variables.py
@@ -71,7 +71,7 @@ def dict_data(self) -> dict:
"""Get all data in the collected variables with a python dict structure."""
new_dict = dict()
for id_, elem in tuple(self.items()):
- new_dict[id_] = elem.value if isinstance(elem, Array) else elem
+ new_dict[id_] = elem._value if isinstance(elem, Array) else elem
return new_dict
def list_data(self) -> list:
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 8447e32e7..4e0c0e188 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -103,6 +103,14 @@ def __instancecheck__(self, instance):
def __class_getitem__(cls, item: type):
return ParamDescInit(item)
+ @property
+ def identifier(self):
+ return self._identifier
+
+ @identifier.setter
+ def identifier(self, value):
+ self._identifier = value
+
class AlignPost(MixIn):
"""Align post MixIn.
@@ -118,9 +126,26 @@ def add_current(self, *args, **kwargs):
@dataclass
class ReturnInfo:
size: Sequence[int]
- axis_names: Optional[Sequence[str]]
- batch_or_mode: Optional[Union[int, bm.Mode]]
- init: Callable
+ axis_names: Optional[Sequence[str]] = None
+ batch_or_mode: Optional[Union[int, bm.Mode]] = None
+ data: Union[Callable, bm.Array, jax.Array] = bm.zeros
+
+ def get_data(self):
+ if isinstance(self.data, Callable):
+ if isinstance(self.batch_or_mode, int):
+ size = (self.batch_or_mode,) + tuple(self.size)
+ elif isinstance(self.batch_or_mode, bm.NonBatchingMode):
+ size = tuple(self.size)
+ elif isinstance(self.batch_or_mode, bm.BatchingMode):
+ size = (self.batch_or_mode.batch_size,) + tuple(self.size)
+ else:
+ size = tuple(self.size)
+ init = self.data(size)
+ elif isinstance(self.data, (bm.Array, jax.Array)):
+ init = self.data
+ else:
+ raise ValueError
+ return init
class AutoDelaySupp(MixIn):
@@ -493,12 +518,13 @@ def __subclasscheck__(self, subclass):
@_SpecialForm
def JointType(self, parameters):
- """Joint type; JointType[X, Y] means either X or Y.
+ """Joint type; JointType[X, Y] means both X and Y.
+
+ To define a union, use e.g. Union[int, str].
- To define a union, use e.g. Union[int, str]. Details:
+ Details:
- The arguments must be types and there must be at least one.
- - None as an argument is a special case and is replaced by
- type(None).
+ - None as an argument is a special case and is replaced by `type(None)`.
- Unions of unions are flattened, e.g.::
JointType[JointType[int, str], float] == JointType[int, str, float]
@@ -519,7 +545,7 @@ def JointType(self, parameters):
- You can use Optional[X] as a shorthand for JointType[X, None].
"""
if parameters == ():
- raise TypeError("Cannot take a Union of no types.")
+ raise TypeError("Cannot take a Joint of no types.")
if not isinstance(parameters, tuple):
parameters = (parameters,)
msg = "JointType[arg, ...]: each arg must be a type."
@@ -540,10 +566,10 @@ class _SpecialForm2(_SpecialForm, _root=True):
def __getitem__(self, parameters):
if self._name == 'JointType':
if parameters == ():
- raise TypeError("Cannot take a Union of no types.")
+ raise TypeError("Cannot take a Joint of no types.")
if not isinstance(parameters, tuple):
parameters = (parameters,)
- msg = "Union[arg, ...]: each arg must be a type."
+ msg = "JointType[arg, ...]: each arg must be a type."
parameters = tuple(_type_check(p, msg) for p in parameters)
parameters = _remove_dups_flatten(parameters)
if len(parameters) == 1:
@@ -555,12 +581,14 @@ def __getitem__(self, parameters):
JointType = _SpecialForm2(
'JointType',
- doc="""Joint type; JointType[X, Y] means either X or Y.
+ doc="""Joint type; JointType[X, Y] means both X and Y.
- To define a union, use e.g. JointType[int, str]. Details:
+ To define a union, use e.g. JointType[int, str].
+
+ Details:
+
- The arguments must be types and there must be at least one.
- - None as an argument is a special case and is replaced by
- type(None).
+ - None as an argument is a special case and is replaced by `type(None)`.
- Unions of unions are flattened, e.g.::
JointType[JointType[int, str], float] == JointType[int, str, float]
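
A hedged usage sketch of the revised `ReturnInfo` dataclass above (the import path `brainpy._src.mixin` is assumed from the file being patched):

    import brainpy.math as bm
    from brainpy._src.mixin import ReturnInfo  # assumed import path

    info = ReturnInfo(size=(4,), batch_or_mode=bm.BatchingMode(2), data=bm.zeros)
    arr = info.get_data()    # a callable `data` is applied to (batch_size,) + size
    print(arr.shape)         # expected: (2, 4)
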
diff --git a/brainpy/check.py b/brainpy/check.py
index 65756d1c9..a1c780106 100644
--- a/brainpy/check.py
+++ b/brainpy/check.py
@@ -507,15 +507,11 @@ def is_instance(
name: str
The checking target name.
"""
- if isinstance(supported_types, type):
- supported_types = (supported_types,)
- if not isinstance(supported_types, (tuple, list)):
- raise TypeError(f'supported_types must be a tuple/list of type. But wwe got {type(supported_types)}')
- for smode in supported_types:
- assert isinstance(smode, type), f'supported_types must be a tuple/list of type. But wwe got {smode}'
+ if not name:
+ name = 'We'
if not isinstance(instance, supported_types):
- raise NotImplementedError(f"{name} does not support {instance}. We only support "
- f"{', '.join([mode.__name__ for mode in supported_types])}. ")
+ raise NotImplementedError(f"{name} expects to get an instance of {supported_types}. "
+ f"But we got {type(instance)}.")
return instance
diff --git a/brainpy/dyn/__init__.py b/brainpy/dyn/__init__.py
index b3272e45a..ab51a9c73 100644
--- a/brainpy/dyn/__init__.py
+++ b/brainpy/dyn/__init__.py
@@ -7,3 +7,4 @@
from .projections import *
from .others import *
from .outs import *
+from .compat import NeuGroup
diff --git a/brainpy/dyn/compat.py b/brainpy/dyn/compat.py
new file mode 100644
index 000000000..b7951ae01
--- /dev/null
+++ b/brainpy/dyn/compat.py
@@ -0,0 +1,10 @@
+
+from brainpy._src.dyn.base import NeuDyn
+
+__all__ = [
+ 'NeuGroup',
+]
+
+NeuGroup = NeuDyn
+
+
diff --git a/docs/conf.py b/docs/conf.py
index f584fb7a8..993d31a44 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -35,11 +35,7 @@
auto_generater.generate_brainpy_docs()
auto_generater.generate_integrators_doc()
auto_generater.generate_math_docs()
-# auto_generater.generate_channels_docs()
-# auto_generater.generate_layers_docs()
-# auto_generater.generate_neurons_docs()
-# auto_generater.generate_rates_docs()
-# auto_generater.generate_synapses_docs()
+auto_generater.generate_mixin_docs()
changelogs = [
diff --git a/examples/dynamics_simulation/COBA-v2.py b/examples/dynamics_simulation/COBA-v2.py
index 4087cdc64..03aa86c61 100644
--- a/examples/dynamics_simulation/COBA-v2.py
+++ b/examples/dynamics_simulation/COBA-v2.py
@@ -1,4 +1,5 @@
import brainpy as bp
+import brainpy.math as bm
neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
@@ -12,7 +13,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars)
self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars)
- self.E2I = bp.dyn.ProjAlignPre(
+ self.E2I = bp.dyn.ProjAlignPreMg1(
pre=self.E,
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
delay=None,
@@ -20,7 +21,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
out=bp.dyn.COBA(E=0.),
post=self.I,
)
- self.E2E = bp.dyn.ProjAlignPre(
+ self.E2E = bp.dyn.ProjAlignPreMg1(
pre=self.E,
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
delay=None,
@@ -28,7 +29,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
out=bp.dyn.COBA(E=0.),
post=self.E,
)
- self.I2E = bp.dyn.ProjAlignPre(
+ self.I2E = bp.dyn.ProjAlignPreMg1(
pre=self.I,
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
delay=None,
@@ -36,7 +37,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
out=bp.dyn.COBA(E=-80.),
post=self.E,
)
- self.I2I = bp.dyn.ProjAlignPre(
+ self.I2I = bp.dyn.ProjAlignPreMg1(
pre=self.I,
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
delay=0.,
@@ -62,7 +63,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars)
self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars)
- self.E2E = bp.dyn.ProjAlignPost(
+ self.E2E = bp.dyn.ProjAlignPostMg2(
pre=self.E,
delay=None,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.E.num), 0.6),
@@ -70,7 +71,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
out=bp.dyn.COBA.desc(E=0.),
post=self.E,
)
- self.E2I = bp.dyn.ProjAlignPost(
+ self.E2I = bp.dyn.ProjAlignPostMg2(
pre=self.E,
delay=None,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.I.num), 0.6),
@@ -78,7 +79,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
out=bp.dyn.COBA.desc(E=0.),
post=self.I,
)
- self.I2E = bp.dyn.ProjAlignPost(
+ self.I2E = bp.dyn.ProjAlignPostMg2(
pre=self.I,
delay=None,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.E.num), 6.7),
@@ -86,7 +87,7 @@ def __init__(self, num_exc, num_inh, inp=20.):
out=bp.dyn.COBA.desc(E=-80.),
post=self.E,
)
- self.I2I = bp.dyn.ProjAlignPost(
+ self.I2I = bp.dyn.ProjAlignPostMg2(
pre=self.I,
delay=None,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.I.num), 6.7),
@@ -147,10 +148,13 @@ def run3():
def run1():
- net = EICOBA_PostAlign(3200, 800)
- runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
- print(runner.run(100., eval_time=True))
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
+ with bm.environment(mode=bm.BatchingMode(10)):
+ net = EICOBA_PostAlign(3200, 800)
+ runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
+ print(runner.run(100., eval_time=True))
+ print(runner.mon['E.spike'].shape)
+ print(runner.mon['ts'].shape)
+ bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'][0], show=True)
def run2():
diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py
deleted file mode 100644
index 4818c3ab9..000000000
--- a/examples/dynamics_simulation/COBA.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import brainpy as bp
-import brainpy.math as bm
-from jax import pmap
-
-bm.set_host_device_count(20)
-
-
-class EINet(bp.DynamicalSystem):
- def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
- super().__init__()
-
- self.bg_exc = e_input
- self.bg_inh = i_input
-
- # network size
- num_exc = int(3200 * scale)
- num_inh = int(800 * scale)
-
- # neurons
- pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.), input_var=False)
- self.E = bp.neurons.LIF(num_exc, **pars)
- self.I = bp.neurons.LIF(num_inh, **pars)
-
- # synapses
- we = 0.6 / scale # excitatory synaptic weight (voltage)
- wi = 6.7 / scale # inhibitory synaptic weight
- self.E2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.), comp_method='dense'
- )
- self.E2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.), comp_method='dense'
- )
- self.I2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.), comp_method='dense'
- )
- self.I2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.), comp_method='dense'
- )
- self.delayE = bp.Delay(self.E.spike, entries={'E': delay})
- self.delayI = bp.Delay(self.I.spike, entries={'I': delay})
-
- def update(self):
- e_spike = self.delayE.at('E')
- i_spike = self.delayI.at('I')
- e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc
- i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh
- self.delayE(self.E(e_inp))
- self.delayI(self.I(i_inp))
-
-
-class EINetv2(bp.DynamicalSystem):
- def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
- super().__init__()
-
- self.bg_exc = e_input
- self.bg_inh = i_input
-
- # network size
- num_exc = int(3200 * scale)
- num_inh = int(800 * scale)
-
- # neurons
- pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.), input_var=False)
- self.E = bp.neurons.LIF(num_exc, **pars)
- self.I = bp.neurons.LIF(num_inh, **pars)
-
- # synapses
- we = 0.6 / scale # excitatory synaptic weight (voltage)
- wi = 6.7 / scale # inhibitory synaptic weight
- self.E2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.)
- )
- self.E2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.)
- )
- self.I2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.)
- )
- self.I2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.)
- )
- bp.share.save('E-spike', bp.Delay(self.E.spike, entries={'E': delay}))
- bp.share.save('I-spike', bp.Delay(self.I.spike, entries={'I': delay}))
-
- def update(self):
- e_spike = bp.share.load('E-spike').at('E')
- i_spike = bp.share.load('I-spike').at('I')
- e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc
- i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh
- self.E(e_inp)
- self.I(i_inp)
-
-
-# simulation
-net = EINet(delay=0., scale=1.)
-runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
-runner.run(100.)
-# print(r)
-bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
-
-# @pmap
-# def f2(I):
-# net = EINet(delay=0., scale=5., e_input=I, i_input=I)
-# # net = EINetv2(delay=0., scale=2.)
-# runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, numpy_mon_after_run=False)
-# runner.run(10000.)
-# return runner.mon
-# # print(r)
-# # bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
-#
-#
-# print(f2(bm.ones(20) * 20.))
-
-
-
-
-
-
-
From 329b6e79c10edbf94a83840fcd258076bbcfe65a Mon Sep 17 00:00:00 2001
From: chaoming
Date: Wed, 19 Jul 2023 23:26:20 +0800
Subject: [PATCH 053/326] add `update()` deprecation warning
---
brainpy/_src/analysis/highdim/slow_points.py | 18 +++----
brainpy/_src/dynsys.py | 52 ++++++++++++++++++--
brainpy/dyn/__init__.py | 1 +
3 files changed, 59 insertions(+), 12 deletions(-)
diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py
index 4c0b82a87..55a7b3207 100644
--- a/brainpy/_src/analysis/highdim/slow_points.py
+++ b/brainpy/_src/analysis/highdim/slow_points.py
@@ -14,8 +14,8 @@
from brainpy import optim, losses
from brainpy._src.analysis import utils, base, constants
from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.context import share
from brainpy._src.runners import check_and_format_inputs, _f_ops
-from brainpy._src.tools.dicts import DotDict
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.types import ArrayType
@@ -123,7 +123,7 @@ def __init__(
f_loss_batch: Callable = None,
fun_inputs: Callable = None,
):
- super(SlowPointFinder, self).__init__()
+ super().__init__()
# static arguments
if not isinstance(args, tuple):
@@ -636,11 +636,11 @@ def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False)
'L': L})
return decompositions
- def _step_func_input(self, shared):
+ def _step_func_input(self):
if self._inputs is None:
return
elif callable(self._inputs):
- self._inputs(shared)
+ self._inputs(share.get_shargs())
else:
for ops, values in self._inputs['fixed'].items():
for var, data in values:
@@ -650,7 +650,7 @@ def _step_func_input(self, shared):
raise UnsupportedError
for ops, values in self._inputs['functional'].items():
for var, data in values:
- _f_ops(ops, var, data(shared))
+ _f_ops(ops, var, data(share.get_shargs()))
for ops, values in self._inputs['iterated'].items():
if len(values) > 0:
raise UnsupportedError
@@ -732,9 +732,10 @@ def _generate_ds_cell_function(
):
if dt is None: dt = bm.get_dt()
if t is None: t = 0.
- shared = DotDict(t=t, dt=dt, i=0)
def f_cell(h: Dict):
+ share.save(t=t, i=0, dt=dt)
+
# update target variables
for k, v in self.target_vars.items():
v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -747,11 +748,10 @@ def f_cell(h: Dict):
# add inputs
target.clear_input()
- self._step_func_input(shared)
+ self._step_func_input()
# call update functions
- args = (shared,) + self.args
- target(*args)
+ target(*self.args)
# get new states
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 02624815a..f14302040 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -3,6 +3,7 @@
import collections
import gc
import inspect
+import warnings
from typing import Union, Dict, Callable, Sequence, Optional, Any
import numpy as np
@@ -28,6 +29,21 @@
SLICE_VARS = 'slice_vars'
+_update_deprecate_msg = '''
+Since brainpy>=2.4.3, the update() function no longer needs to receive a global shared argument.
+
+Instead of using:
+
+ def update(self, tdi, *args, **kwargs):
+ ...
+
+Please use:
+
+ def update(self, *args, **kwargs):
+ t = bp.share['t']
+ ...
+'''
+
def not_pass_shared(func: Callable):
"""Label the update function as the one without passing shared arguments.
@@ -160,13 +176,38 @@ def clear_input(self):
pass
def step_run(self, i, *args, **kwargs):
+ """The step run function.
+
+ This function can be directly applied to run the dynamical system.
+ Particularly, ``i`` denotes the running index.
+
+ Args:
+ i: The current running index.
+ *args: The positional arguments of the ``update()`` function.
+ **kwargs: The keyword arguments of the ``update()`` function.
+
+ Returns:
+ out: The return value of the ``update()`` function.
+ """
global share
if share is None:
from brainpy._src.context import share
share.save(i=i, t=i * bm.dt)
return self.update(*args, **kwargs)
- jit_step_run = bm.cls_jit(step_run, inline=True)
+ @bm.cls_jit(inline=True)
+ def jit_step_run(self, i, *args, **kwargs):
+ """The jitted step function for running.
+
+ Args:
+ i: The current running index.
+ *args: The positional arguments of the ``update()`` function.
+ **kwargs: The keyword arguments of the ``update()`` function.
+
+ Returns:
+ out: The return value of the ``update()`` function.
+ """
+ return self.step_run(i, *args, **kwargs)
@property
def mode(self) -> bm.Mode:
@@ -189,19 +230,20 @@ def _compatible_update(self, *args, **kwargs):
if len(update_args) and update_args[0].name in ['tdi', 'sh', 'sha']:
if len(args) > 0:
- if isinstance(args[0], dict):
+ if isinstance(args[0], dict) and all([bm.isscalar(v) for v in args[0].values()]):
# define:
# update(tdi, *args, **kwargs)
# call:
# update(tdi, *args, **kwargs)
ret = update_fun(*args, **kwargs)
- # TODO: deprecation
+ warnings.warn(_update_deprecate_msg, UserWarning)
else:
# define:
# update(tdi, *args, **kwargs)
# call:
# update(*args, **kwargs)
ret = update_fun(share.get_shargs(), *args, **kwargs)
+ warnings.warn(_update_deprecate_msg, UserWarning)
else:
if update_args[0].name in kwargs:
# define:
@@ -209,12 +251,14 @@ def _compatible_update(self, *args, **kwargs):
# call:
# update(tdi=??, **kwargs)
ret = update_fun(**kwargs)
+ warnings.warn(_update_deprecate_msg, UserWarning)
else:
# define:
# update(tdi, *args, **kwargs)
# call:
# update(**kwargs)
ret = update_fun(share.get_shargs(), *args, **kwargs)
+ warnings.warn(_update_deprecate_msg, UserWarning)
return ret
try:
@@ -230,6 +274,7 @@ def _compatible_update(self, *args, **kwargs):
# update(*args, **kwargs)
share.save(**args[0])
ret = update_fun(*args[1:], **kwargs)
+ warnings.warn(_update_deprecate_msg, UserWarning)
return ret
else:
# user define ``update()`` function which receives the shared argument,
@@ -240,6 +285,7 @@ def _compatible_update(self, *args, **kwargs):
# as
# update(tdi, *args, **kwargs)
ret = update_fun(share.get_shargs(), *args, **kwargs)
+ warnings.warn(_update_deprecate_msg, UserWarning)
return ret
else:
return update_fun(*args, **kwargs)
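
A minimal sketch of the post-2.4.3 style that this patch starts warning users toward, assuming only the public APIs shown elsewhere in the diff (`bp.share`, `bm.Variable`, `step_run`):

    import brainpy as bp
    import brainpy.math as bm

    class Accumulator(bp.DynamicalSystem):
        def __init__(self):
            super().__init__()
            self.total = bm.Variable(bm.zeros(1))

        def update(self, x):
            t = bp.share['t']   # shared values come from the global context, not a `tdi` argument
            self.total += x
            return self.total.value

    model = Accumulator()
    out = model.step_run(0, 1.)  # step_run saves `i` and `t` into the context, then calls update()
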
diff --git a/brainpy/dyn/__init__.py b/brainpy/dyn/__init__.py
index ab51a9c73..297c0c50b 100644
--- a/brainpy/dyn/__init__.py
+++ b/brainpy/dyn/__init__.py
@@ -7,4 +7,5 @@
from .projections import *
from .others import *
from .outs import *
+from .rates import *
from .compat import NeuGroup
From cce047c45922247753399caf3c7a29546136110f Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 10:19:59 +0800
Subject: [PATCH 054/326] update examples
---
.../2d_fitzhugh_nagumo_model.py | 5 +-
.../dynamics_analysis/2d_mean_field_QIF.py | 7 +--
.../dynamics_analysis/3d_reduced_trn_model.py | 4 +-
.../dynamics_analysis/highdim_RNN_Analysis.py | 4 +-
.../{COBA-v2.py => COBA.py} | 26 ++++++----
examples/dynamics_simulation/hh_model.py | 48 ++++---------------
.../dynamics_simulation/multi_scale_COBAHH.py | 7 +--
.../whole_brain_simulation_with_fhn.py | 10 ++--
...ole_brain_simulation_with_sl_oscillator.py | 12 ++---
.../dynamics_training/Song_2016_EI_RNN.py | 2 +-
examples/training_ann_models/mnist-cnn.py | 26 +++++-----
examples/training_ann_models/mnist_ResNet.py | 34 ++++++-------
12 files changed, 83 insertions(+), 102 deletions(-)
rename examples/dynamics_simulation/{COBA-v2.py => COBA.py} (95%)
diff --git a/examples/dynamics_analysis/2d_fitzhugh_nagumo_model.py b/examples/dynamics_analysis/2d_fitzhugh_nagumo_model.py
index b1dd0e655..73af38f2e 100644
--- a/examples/dynamics_analysis/2d_fitzhugh_nagumo_model.py
+++ b/examples/dynamics_analysis/2d_fitzhugh_nagumo_model.py
@@ -33,8 +33,9 @@ def dw(w, t, V, a=0.7, b=0.8):
self.int_V = bp.odeint(dV, method=method)
self.int_w = bp.odeint(dw, method=method)
- def update(self, tdi):
- t, dt = tdi['t'], tdi['dt']
+ def update(self):
+ t = bp.share['t']
+ dt = bp.share['dt']
self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt)
self.Iext[:] = 0.
diff --git a/examples/dynamics_analysis/2d_mean_field_QIF.py b/examples/dynamics_analysis/2d_mean_field_QIF.py
index 467bc6118..28be6a51d 100644
--- a/examples/dynamics_analysis/2d_mean_field_QIF.py
+++ b/examples/dynamics_analysis/2d_mean_field_QIF.py
@@ -14,7 +14,7 @@ class MeanFieldQIF(bp.DynamicalSystem):
"""
def __init__(self, method='exp_auto'):
- super(MeanFieldQIF, self).__init__()
+ super().__init__()
# parameters
self.tau = 1. # the population time constant
@@ -38,8 +38,9 @@ def dv(v, t, r, Iext=0., eta=-5.0):
self.int_r = bp.odeint(dr, method=method)
self.int_v = bp.odeint(dv, method=method)
- def update(self, tdi):
- t, dt = tdi['t'], tdi['dt']
+ def update(self):
+ t = bp.share['t']
+ dt = bp.share['dt']
self.r.value = self.int_r(self.r, t, self.v, self.delta, dt)
self.v.value = self.int_v(self.v, t, self.r, self.Iext, self.eta, dt)
self.Iext[:] = 0.
diff --git a/examples/dynamics_analysis/3d_reduced_trn_model.py b/examples/dynamics_analysis/3d_reduced_trn_model.py
index fde3da625..90dd20c49 100644
--- a/examples/dynamics_analysis/3d_reduced_trn_model.py
+++ b/examples/dynamics_analysis/3d_reduced_trn_model.py
@@ -7,9 +7,9 @@
bp.math.set_platform('cpu')
-class ReducedTRNModel(bp.NeuDyn):
+class ReducedTRNModel(bp.dyn.NeuDyn):
def __init__(self, size, name=None, T=36., method='rk4'):
- super(ReducedTRNModel, self).__init__(size=size, name=name)
+ super().__init__(size=size, name=name)
self.IT_th = -3.
self.b = 0.5
diff --git a/examples/dynamics_analysis/highdim_RNN_Analysis.py b/examples/dynamics_analysis/highdim_RNN_Analysis.py
index 75b844247..cd9d76829 100644
--- a/examples/dynamics_analysis/highdim_RNN_Analysis.py
+++ b/examples/dynamics_analysis/highdim_RNN_Analysis.py
@@ -26,7 +26,7 @@ def __init__(
w_rr=bp.init.KaimingNormal(scale=1.),
w_ro=bp.init.KaimingNormal(scale=1.)
):
- super(RNNNet, self).__init__()
+ super().__init__()
self.tau = 100
self.num_input = num_input
@@ -64,7 +64,7 @@ def cell(self, x, h):
def readout(self, h):
return h @ self.w_ro + self.b_ro
- def update(self, sha, x):
+ def update(self, x):
self.h.value = self.cell(x, self.h.value)
return self.readout(self.h.value)
diff --git a/examples/dynamics_simulation/COBA-v2.py b/examples/dynamics_simulation/COBA.py
similarity index 95%
rename from examples/dynamics_simulation/COBA-v2.py
rename to examples/dynamics_simulation/COBA.py
index 03aa86c61..043ede354 100644
--- a/examples/dynamics_simulation/COBA-v2.py
+++ b/examples/dynamics_simulation/COBA.py
@@ -140,12 +140,6 @@ def __init__(self, scale=1.0, method='exp_auto'):
# bm.set_host_device_count(num_device)
# bm.sharding.set(mesh_axes=(bp.dyn.PNEU_AXIS,), mesh_shape=(num_device, ))
-def run3():
- net = EICOBA_PreAlign(3200, 800)
- runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
- print(runner.run(100., eval_time=True))
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
-
def run1():
with bm.environment(mode=bm.BatchingMode(10)):
@@ -167,7 +161,23 @@ def run2():
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
+def run3():
+ net = EICOBA_PreAlign(3200, 800)
+ runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
+ print(runner.run(100., eval_time=True))
+ bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
+
+
+
+def run4():
+ net = EICOBA_PostAlign(3200, 800)
+ runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
+ print(runner.run(100., eval_time=True))
+ bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
+
+
if __name__ == '__main__':
- # run1()
+ run1()
run2()
- # run3()
+ run3()
+ run4()
diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py
index 06b435595..6b64a6c10 100644
--- a/examples/dynamics_simulation/hh_model.py
+++ b/examples/dynamics_simulation/hh_model.py
@@ -11,32 +11,27 @@
class HH(bp.dyn.CondNeuGroup):
def __init__(self, size):
- super().__init__(size, keep_size=True)
+ super().__init__(size)
- self.INa = bp.channels.INa_HH1952(size, keep_size=True)
- self.IK = bp.channels.IK_HH1952(size, keep_size=True)
- self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03, keep_size=True)
+ self.INa = bp.channels.INa_HH1952(size)
+ self.IK = bp.channels.IK_HH1952(size)
+ self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)
class HHv2(bp.dyn.CondNeuGroupLTC):
def __init__(self, size):
- super().__init__(size, keep_size=True)
+ super().__init__(size)
self.Na = bp.dyn.SodiumFixed(size, E=50.)
- self.Na.add(ina=bp.dyn.INa_HH1952v2(size, keep_size=True))
+ self.Na.add_elem(ina=bp.dyn.INa_HH1952v2(size))
self.K = bp.dyn.PotassiumFixed(size, E=50.)
- self.K.add(ik=bp.dyn.IK_HH1952v2(size, keep_size=True))
-
- self.IL = bp.dyn.IL(size, E=-54.387, g_max=0.03, keep_size=True)
-
- self.KNa = bp.dyn.mixs(self.Na, self.K)
- self.KNa.add()
-
-
-
+ self.K.add_elem(ik=bp.dyn.IK_HH1952v2(size))
+ self.IL = bp.dyn.IL(size, E=-54.387, g_max=0.03)
+ self.KNa = bp.dyn.MixIons(self.Na, self.K)
+ self.KNa.add_elem()
# hh = HH(1)
@@ -52,26 +47,3 @@ def __init__(self, size):
#
# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
-
-hh = HH((20, 10000))
-variables = hh.vars().unique()
-
-
-iis = np.arange(1000000000)
-
-def f(i):
- bp.share.save(i=i, t=i * bm.get_dt(), dt=bm.get_dt())
- hh(5.)
-
-
-@pmap
-def run(vars):
- for v, d in vars.items():
- variables[v]._value = d
- bm.for_loop(f, bm.arange(1000000000))
- print('Compiling End')
- return hh.spike
-
-
-r = run(variables.dict())
-print(r.shape)
diff --git a/examples/dynamics_simulation/multi_scale_COBAHH.py b/examples/dynamics_simulation/multi_scale_COBAHH.py
index cd1e6b355..14bea66fe 100644
--- a/examples/dynamics_simulation/multi_scale_COBAHH.py
+++ b/examples/dynamics_simulation/multi_scale_COBAHH.py
@@ -7,12 +7,9 @@
import brainpy as bp
import brainpy.math as bm
-from brainpy.channels import INa_TM1991, IL
-from brainpy.synapses import Exponential, COBA
from brainpy.connect import FixedProb
from jax import vmap
-comp_method = 'sparse'
area_names = ['V1', 'V2', 'V4', 'TEO', 'TEpd']
@@ -47,8 +44,8 @@ class HH(bp.CondNeuGroup):
def __init__(self, size):
super(HH, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.))
self.IK = IK(size, g_max=30., V_sh=-63.)
- self.INa = INa_TM1991(size, g_max=100., V_sh=-63.)
- self.IL = IL(size, E=-60., g_max=0.05)
+ self.INa = bp.dyn.INa_TM1991(size, g_max=100., V_sh=-63.)
+ self.IL = bp.dyn.IL(size, E=-60., g_max=0.05)
class Network(bp.Network):
diff --git a/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py b/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py
index acc530986..3f1be523b 100644
--- a/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py
+++ b/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py
@@ -21,7 +21,7 @@ def bifurcation_analysis():
pp.show_figure()
-class Network(bp.Network):
+class Network(bp.DynSysGroup):
def __init__(self, signal_speed=20.):
super(Network, self).__init__()
@@ -36,12 +36,12 @@ def __init__(self, signal_speed=20.):
delay_mat = bm.asarray(delay_mat)
bm.fill_diagonal(delay_mat, 0)
- self.fhn = bp.rates.FHN(
+ self.fhn = bp.dyn.FHN(
80,
x_ou_sigma=0.01,
y_ou_sigma=0.01,
)
- self.coupling = bp.synapses.DiffusiveCoupling(
+ self.coupling = bp.dyn.DiffusiveCoupling(
self.fhn.x,
self.fhn.x,
var_to_output=self.fhn.input,
@@ -95,5 +95,5 @@ def net_analysis():
if __name__ == '__main__':
# bifurcation_analysis()
- # net_simulation()
- net_analysis()
+ net_simulation()
+ # net_analysis()
diff --git a/examples/dynamics_simulation/whole_brain_simulation_with_sl_oscillator.py b/examples/dynamics_simulation/whole_brain_simulation_with_sl_oscillator.py
index b7f3b45c3..b2cbdaacd 100644
--- a/examples/dynamics_simulation/whole_brain_simulation_with_sl_oscillator.py
+++ b/examples/dynamics_simulation/whole_brain_simulation_with_sl_oscillator.py
@@ -10,7 +10,7 @@
def bifurcation_analysis():
- model = bp.rates.StuartLandauOscillator(1, method='exp_auto')
+ model = bp.dyn.StuartLandauOscillator(1, method='exp_auto')
pp = bp.analysis.Bifurcation2D(
model,
target_vars={'x': [-2, 2], 'y': [-2, 2]},
@@ -22,7 +22,7 @@ def bifurcation_analysis():
pp.show_figure()
-class Network(bp.Network):
+class Network(bp.DynSysGroup):
def __init__(self, noise=0.14):
super(Network, self).__init__()
@@ -35,8 +35,8 @@ def __init__(self, noise=0.14):
bm.fill_diagonal(conn_mat, 0)
gc = 0.6 # global coupling strength
- self.sl = bp.rates.StuartLandauOscillator(80, x_ou_sigma=noise, y_ou_sigma=noise)
- self.coupling = bp.synapses.DiffusiveCoupling(
+ self.sl = bp.dyn.StuartLandauOscillator(80, x_ou_sigma=noise, y_ou_sigma=noise)
+ self.coupling = bp.dyn.DiffusiveCoupling(
self.sl.x, self.sl.x,
var_to_output=self.sl.input,
conn_mat=conn_mat * gc
@@ -87,6 +87,6 @@ def net_analysis():
if __name__ == '__main__':
- bifurcation_analysis()
+ # bifurcation_analysis()
simulation()
- net_analysis()
+ # net_analysis()
diff --git a/examples/dynamics_training/Song_2016_EI_RNN.py b/examples/dynamics_training/Song_2016_EI_RNN.py
index e4a19ba7b..f3aef2aeb 100644
--- a/examples/dynamics_training/Song_2016_EI_RNN.py
+++ b/examples/dynamics_training/Song_2016_EI_RNN.py
@@ -27,7 +27,7 @@ def __init__(
w_rr=bp.init.KaimingUniform(scale=1.),
w_ro=bp.init.KaimingUniform(scale=1.)
):
- super(EI_RNN, self).__init__()
+ super().__init__()
# parameters
self.tau = 100
diff --git a/examples/training_ann_models/mnist-cnn.py b/examples/training_ann_models/mnist-cnn.py
index 602191156..96b9b0ccd 100644
--- a/examples/training_ann_models/mnist-cnn.py
+++ b/examples/training_ann_models/mnist-cnn.py
@@ -10,20 +10,20 @@
class FeedForwardModel(bp.DynamicalSystem):
def __init__(self):
super(FeedForwardModel, self).__init__()
- self.conv1 = bp.layers.Conv2d(1, 32, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
- self.pool = bp.layers.MaxPool(2, 2, channel_axis=-1)
- self.conv2 = bp.layers.Conv2d(32, 64, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
- self.fc1 = bp.layers.Dense(64 * 7 * 7, 1024)
- self.fc2 = bp.layers.Dense(1024, 512)
- self.fc3 = bp.layers.Dense(512, 10)
-
- def update(self, s, x):
- x = self.pool(s, bm.relu(self.conv1(s, x)))
- x = self.pool(s, bm.relu(self.conv2(s, x)))
+ self.conv1 = bp.dnn.Conv2d(1, 32, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
+ self.pool = bp.dnn.MaxPool(2, 2, channel_axis=-1)
+ self.conv2 = bp.dnn.Conv2d(32, 64, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
+ self.fc1 = bp.dnn.Dense(64 * 7 * 7, 1024)
+ self.fc2 = bp.dnn.Dense(1024, 512)
+ self.fc3 = bp.dnn.Dense(512, 10)
+
+ def update(self, x):
+ x = self.pool(bm.relu(self.conv1(x)))
+ x = self.pool(bm.relu(self.conv2(x)))
x = x.reshape(-1, 64 * 7 * 7)
- x = bm.relu(self.fc1(s, x))
- x = bm.relu(self.fc2(s, x))
- x = self.fc3(s, x)
+ x = bm.relu(self.fc1(x))
+ x = bm.relu(self.fc2(x))
+ x = self.fc3(x)
return x
diff --git a/examples/training_ann_models/mnist_ResNet.py b/examples/training_ann_models/mnist_ResNet.py
index 9a74ddbb9..210aa1ea9 100644
--- a/examples/training_ann_models/mnist_ResNet.py
+++ b/examples/training_ann_models/mnist_ResNet.py
@@ -39,10 +39,10 @@ def __init__(self, in_planes, planes, stride=1, is_last=False):
bp.layers.BatchNorm2D(self.expansion * planes)
)
- def update(self, s, x):
- out = bm.relu(self.bn1(s, self.conv1(s, x)))
- out = self.bn2(s, self.conv2(s, out))
- out += self.shortcut(s, x)
+ def update(self, x):
+ out = bm.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
preact = out
out = bm.relu(out)
if self.is_last:
@@ -77,10 +77,10 @@ def __init__(self, in_planes, planes, stride=1, is_last=False):
)
def update(self, s, x):
- out = bm.relu(self.bn1(s, self.conv1(s, x)))
- out = bm.relu(self.bn2(s, self.conv2(s, out)))
- out = self.bn3(s, self.conv3(s, out))
- out += self.shortcut(s, x)
+ out = bm.relu(self.bn1(self.conv1(x)))
+ out = bm.relu(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
preact = out
out = bm.relu(out)
if self.is_last:
@@ -141,21 +141,21 @@ def _make_layer(self, block, planes, num_blocks, stride):
return bp.Sequential(*layers)
def update(self, s, x, is_feat=False, preact=False):
- out = bm.relu(self.bn1(s, self.conv1(s, x)))
+ out = bm.relu(self.bn1(self.conv1(x)))
f0 = out
- out, f1_pre = self.layer1(s, out)
+ out, f1_pre = self.layer1(out)
f1 = out
- out, f2_pre = self.layer2(s, out)
+ out, f2_pre = self.layer2(out)
f2 = out
- out, f3_pre = self.layer3(s, out)
+ out, f3_pre = self.layer3(out)
f3 = out
- out, f4_pre = self.layer4(s, out)
+ out, f4_pre = self.layer4(out)
f4 = out
- # out = self.avgpool(s, out)
+ # out = self.avgpool(out)
# out = out.reshape(128, -1)
out = bm.mean(out, axis=(1, 2))
f5 = out
- out = self.linear(s, out)
+ out = self.linear(out)
if is_feat:
if preact:
return [[f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out]
@@ -213,8 +213,8 @@ def main():
# loss function
def loss_fun(X, Y, fit=True):
- s = {'fit': fit}
- predictions = net(s, X)
+ bp.share.save(fit=fit)
+ predictions = net(X)
l = bp.losses.cross_entropy_loss(predictions, Y)
n = bm.sum(predictions.argmax(1) == Y)
return l, n
From 7c56adf17b068091a195eaaaa9f850cae5d1df0c Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 10:20:53 +0800
Subject: [PATCH 055/326] upgrade `brainpy.analysis` for new version of
`DynamicalSystem`
---
brainpy/_src/analysis/highdim/slow_points.py | 13 ++++-
.../_src/analysis/lowdim/lowdim_analyzer.py | 13 ++---
.../analysis/lowdim/lowdim_bifurcation.py | 56 +++++++++----------
.../analysis/lowdim/lowdim_phase_plane.py | 32 +++++------
brainpy/_src/analysis/utils/model.py | 11 ++--
5 files changed, 64 insertions(+), 61 deletions(-)
diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py
index 55a7b3207..3ec96e440 100644
--- a/brainpy/_src/analysis/highdim/slow_points.py
+++ b/brainpy/_src/analysis/highdim/slow_points.py
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
+import inspect
import math
import time
+import warnings
from typing import Callable, Union, Dict, Sequence, Tuple
import jax.numpy as jnp
@@ -18,6 +20,8 @@
from brainpy._src.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
from brainpy.types import ArrayType
+from brainpy._src.deprecations import _input_deprecate_msg
+
__all__ = [
'SlowPointFinder',
@@ -514,7 +518,7 @@ def exclude_outliers(self, tolerance: float = 1e0):
# Compute pairwise distances between all fixed points.
distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps))
- # Find second smallest element in each column of the pairwise distance matrix.
+ # Find the second smallest element in each column of the pairwise distance matrix.
# This corresponds to the closest neighbor for each fixed point.
closest_neighbor = np.partition(distances, kth=1, axis=0)[1]
@@ -640,7 +644,12 @@ def _step_func_input(self):
if self._inputs is None:
return
elif callable(self._inputs):
- self._inputs(share.get_shargs())
+ try:
+ ba = inspect.signature(self._inputs).bind(dict())
+ self._inputs(share.get_shargs())
+ warnings.warn(_input_deprecate_msg, UserWarning)
+ except TypeError:
+ self._inputs()
else:
for ops, values in self._inputs['fixed'].items():
for var, data in values:
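
A hedged illustration of the compatibility check added above: callables that still expect the shared-argument dict are detected with `inspect.signature` and trigger the deprecation warning, while zero-argument callables are invoked directly (the shared values below are placeholders, not brainpy's real context):

    import inspect
    import warnings

    def call_input_function(fn):
        try:
            inspect.signature(fn).bind(dict())   # would an old-style call bind?
            warnings.warn('old-style input function; drop the shared argument', UserWarning)
            return fn(dict(t=0., dt=0.1, i=0))   # placeholder shared values
        except TypeError:
            return fn()                          # new style: no shared argument

    call_input_function(lambda shared: shared['t'])  # old style -> warns
    call_input_function(lambda: 0.0)                 # new style
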
diff --git a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py
index 4303543b8..f186659e9 100644
--- a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py
+++ b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py
@@ -99,7 +99,8 @@ def __init__(
raise errors.AnalyzerError(f'{key} is not a dynamical variable in {self.model}.')
value = self.target_vars[key]
if value[0] > value[1]:
- raise errors.AnalyzerError(f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
+ raise errors.AnalyzerError(
+ f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
# fixed variables
# ----------------
@@ -246,7 +247,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
"""
def __init__(self, *args, **kwargs):
- super(Num1DAnalyzer, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
self.x_var = self.target_var_names[0]
if len(self.target_vars) < 1:
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
@@ -407,7 +408,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
"""
def __init__(self, *args, **kwargs):
- super(Num2DAnalyzer, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
if len(self.target_vars) < 2:
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
f'with >= 2 variables. But we got {len(self.target_vars)} '
@@ -1028,7 +1029,7 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
class Num3DAnalyzer(Num2DAnalyzer):
def __init__(self, *args, **kwargs):
- super(Num3DAnalyzer, self).__init__(*args, **kwargs)
+ super().__init__(*args, **kwargs)
if len(self.target_vars) < 3:
raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system '
f'with >= 3 variables. But we got {len(self.target_vars)} '
@@ -1045,7 +1046,3 @@ def F_fz(self):
f = partial(f, **(self.pars_update + self.fixed_vars))
self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device)
return self.analyzed_results[C.F_fz]
-
- def fz_signs(self, pars=(), cache=False):
- xyz = tuple(self.resolutions.values())
- return utils.get_sign2(self.F_fz, *xyz, args=pars)
diff --git a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
index b157adc16..97a8d3b59 100644
--- a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
+++ b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
@@ -31,13 +31,13 @@ class Bifurcation1D(Num1DAnalyzer):
def __init__(self, model, target_pars, target_vars, fixed_vars=None,
pars_update=None, resolutions=None, options=None):
- super(Bifurcation1D, self).__init__(model=model,
- target_pars=target_pars,
- target_vars=target_vars,
- fixed_vars=fixed_vars,
- pars_update=pars_update,
- resolutions=resolutions,
- options=options)
+ super().__init__(model=model,
+ target_pars=target_pars,
+ target_vars=target_vars,
+ fixed_vars=fixed_vars,
+ pars_update=pars_update,
+ resolutions=resolutions,
+ options=options)
if len(self.target_pars) == 0:
raise ValueError
@@ -146,13 +146,13 @@ class Bifurcation2D(Num2DAnalyzer):
def __init__(self, model, target_pars, target_vars, fixed_vars=None,
pars_update=None, resolutions=None, options=None):
- super(Bifurcation2D, self).__init__(model=model,
- target_pars=target_pars,
- target_vars=target_vars,
- fixed_vars=fixed_vars,
- pars_update=pars_update,
- resolutions=resolutions,
- options=options)
+ super().__init__(model=model,
+ target_pars=target_pars,
+ target_vars=target_vars,
+ fixed_vars=fixed_vars,
+ pars_update=pars_update,
+ resolutions=resolutions,
+ options=options)
if len(self.target_pars) == 0:
raise ValueError
@@ -458,13 +458,13 @@ def __init__(
resolutions=None,
options: dict = None
):
- super(FastSlow1D, self).__init__(model=model,
- target_pars=slow_vars,
- target_vars=fast_vars,
- fixed_vars=fixed_vars,
- pars_update=pars_update,
- resolutions=resolutions,
- options=options)
+ super().__init__(model=model,
+ target_pars=slow_vars,
+ target_vars=fast_vars,
+ fixed_vars=fixed_vars,
+ pars_update=pars_update,
+ resolutions=resolutions,
+ options=options)
# standard integrators
self._std_integrators = dict()
@@ -549,13 +549,13 @@ def __init__(
resolutions=0.1,
options: dict = None
):
- super(FastSlow2D, self).__init__(model=model,
- target_pars=slow_vars,
- target_vars=fast_vars,
- fixed_vars=fixed_vars,
- pars_update=pars_update,
- resolutions=resolutions,
- options=options)
+ super().__init__(model=model,
+ target_pars=slow_vars,
+ target_vars=fast_vars,
+ fixed_vars=fixed_vars,
+ pars_update=pars_update,
+ resolutions=resolutions,
+ options=options)
# standard integrators
self._std_integrators = dict()
for key, intg in self.model.name2integral.items():
diff --git a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
index 7b3527329..b3df8e1ee 100644
--- a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
+++ b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
@@ -55,13 +55,13 @@ def __init__(self,
if (target_pars is not None) and len(target_pars) > 0:
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
f'While we detect "target_pars={target_pars}".')
- super(PhasePlane1D, self).__init__(model=model,
- target_vars=target_vars,
- fixed_vars=fixed_vars,
- target_pars=target_pars,
- pars_update=pars_update,
- resolutions=resolutions,
- **kwargs)
+ super().__init__(model=model,
+ target_vars=target_vars,
+ fixed_vars=fixed_vars,
+ target_pars=target_pars,
+ pars_update=pars_update,
+ resolutions=resolutions,
+ **kwargs)
# utils.output(f'I am {PhasePlane1D.__name__}.')
def plot_vector_field(self, show=False, with_plot=True, with_return=False):
@@ -150,13 +150,13 @@ def __init__(self,
if (target_pars is not None) and len(target_pars) > 0:
raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". '
f'While we detect "target_pars={target_pars}".')
- super(PhasePlane2D, self).__init__(model=model,
- target_vars=target_vars,
- fixed_vars=fixed_vars,
- target_pars=target_pars,
- pars_update=pars_update,
- resolutions=resolutions,
- **kwargs)
+ super().__init__(model=model,
+ target_vars=target_vars,
+ fixed_vars=fixed_vars,
+ target_pars=target_pars,
+ pars_update=pars_update,
+ resolutions=resolutions,
+ **kwargs)
@property
def F_vmap_brentq_fy(self):
@@ -251,7 +251,7 @@ def plot_nullcline(self, with_plot=True, with_return=False,
if with_plot:
if x_style is None:
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
- line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
+ line_args = (x_style.pop('fmt'),) if 'fmt' in x_style else tuple()
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")
# Nullcline of the y variable
@@ -263,7 +263,7 @@ def plot_nullcline(self, with_plot=True, with_return=False,
if with_plot:
if y_style is None:
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
- line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
+ line_args = (y_style.pop('fmt'),) if 'fmt' in y_style else tuple()
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")
if with_plot:
diff --git a/brainpy/_src/analysis/utils/model.py b/brainpy/_src/analysis/utils/model.py
index a2c92fc97..6acc3f456 100644
--- a/brainpy/_src/analysis/utils/model.py
+++ b/brainpy/_src/analysis/utils/model.py
@@ -5,6 +5,7 @@
from brainpy._src.math.environment import get_float
from brainpy._src.math.interoperability import as_jax
from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.context import share
from brainpy._src.runners import DSRunner
from brainpy._src.integrators.base import Integrator
from brainpy._src.integrators.joint_eq import JointEq
@@ -126,16 +127,12 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
self.integrals = integrals
# runner
- self.runner = DSRunner(self,
- monitors=list(initial_vars.keys()),
- dyn_vars=self.vars().unique(),
- dt=dt,
- progress_bar=False)
+ self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False)
- def update(self, sha):
+ def update(self):
all_vars = list(self.implicit_vars.values())
for key, intg in self.integrals.items():
- self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=sha['dt']))
+ self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt']))
def __getattr__(self, item):
child_vars = super(TrajectModel, self).__getattribute__('implicit_vars')
From c043281cc06925dade208b9c9afb33c2d933c718 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 10:21:56 +0800
Subject: [PATCH 056/326] upgrade `brainpy.train` for new version of
`DynamicalSystem`
---
brainpy/_src/runners.py | 193 +++++++++++--------------
brainpy/_src/running/runner.py | 1 -
brainpy/_src/train/back_propagation.py | 161 +++++++++------------
brainpy/_src/train/base.py | 12 +-
brainpy/_src/train/offline.py | 81 +++++------
brainpy/_src/train/online.py | 106 ++++++--------
6 files changed, 228 insertions(+), 326 deletions(-)
diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py
index 1bfd9cc61..42b40b88e 100644
--- a/brainpy/_src/runners.py
+++ b/brainpy/_src/runners.py
@@ -1,10 +1,10 @@
# -*- coding: utf-8 -*-
-
+import functools
+import inspect
import time
import warnings
from collections.abc import Iterable
-from functools import partial
-from typing import Dict, Union, Sequence, Callable, Tuple, Optional
+from typing import Dict, Union, Sequence, Callable, Tuple, Optional, Any
import jax
import jax.numpy as jnp
@@ -14,13 +14,12 @@
from jax.tree_util import tree_map, tree_flatten
from brainpy import math as bm, tools
-from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
+from brainpy._src.deprecations import _input_deprecate_msg
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.running.runner import Runner
-from brainpy.check import serialize_kwargs
from brainpy.errors import RunningError
-from brainpy.types import ArrayType, Output, Monitor
-
+from brainpy.types import Output, Monitor
__all__ = [
'DSRunner',
@@ -30,6 +29,16 @@
SUPPORTED_INPUT_TYPE = ['fix', 'iter', 'func']
+def _call_fun_with_share(f, *args, **kwargs):
+ try:
+ sha = share.get_shargs()
+ inspect.signature(f).bind(sha, *args, **kwargs)
+ warnings.warn(_input_deprecate_msg, UserWarning)
+ return f(sha, *args, **kwargs)
+ except TypeError:
+ return f(*args, **kwargs)
+
+
def _is_brainpy_array(x):
return isinstance(x, bm.Array)
@@ -78,7 +87,6 @@ def check_and_format_inputs(host, inputs):
# 2. get targets and attributes
# ---------
inputs_which_found_target = []
- inputs_not_found_target = []
# checking 1: absolute access
# Check whether the input target node is accessible,
@@ -101,22 +109,6 @@ def check_and_format_inputs(host, inputs):
f'specify variable of the target, but we got {key}.')
inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
- # checking 2: relative access
- # Check whether the input target node is accessible
- # and check whether the target node has the attribute
- # if len(inputs_not_found_target):
- # nodes = host.nodes(method='relative', level=-1, include_self=True)
- # for one_input in inputs_not_found_target:
- # splits = one_input[0].split('.')
- # target, key = '.'.join(splits[:-1]), splits[-1]
- # if target not in nodes:
- # raise RunningError(f'Input target "{target}" is not defined in {host}.')
- # real_target = nodes[target]
- # if not hasattr(real_target, key):
- # raise RunningError(f'Input target key "{key}" is not defined in {real_target}.')
- # real_target = getattr(real_target, key)
- # inputs_which_found_target.append((real_target,) + tuple(one_input[1:]))
-
# 3. format inputs
# ---------
formatted_inputs = []
@@ -257,7 +249,7 @@ class DSRunner(Runner):
- A list of string with index specification. Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']``
- A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}``
- A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}``
- - A dict with the callable function, like ``monitors={'a': lambda tdi: model.spike[:5]}``
+ - A dict with the callable function, like ``monitors={'a': lambda: model.spike[:5]}``
.. versionchanged:: 2.3.1
``fun_monitors`` are merged into ``monitors``.
@@ -266,8 +258,8 @@ class DSRunner(Runner):
The dict ``key`` should be a string for the later retrieval by ``runner.mon[key]``.
The dict ``value`` should be a callable function which receives two arguments: ``t`` and ``dt``.
.. code-block::
- fun_monitors = {'spike': lambda tdi: model.spike[:10],
- 'V10': lambda tdi: model.V[10]}
+ fun_monitors = {'spike': lambda: model.spike[:10],
+ 'V10': lambda: model.V[10]}
.. deprecated:: 2.3.1
Will be removed since version 2.4.0.
@@ -334,17 +326,16 @@ def __init__(
if not isinstance(target, DynamicalSystem):
raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, '
f'but we got {type(target)}: {target}')
- super(DSRunner, self).__init__(target=target,
- monitors=monitors,
- fun_monitors=fun_monitors,
- jit=jit,
- progress_bar=progress_bar,
- dyn_vars=dyn_vars,
- numpy_mon_after_run=numpy_mon_after_run)
+ super().__init__(target=target,
+ monitors=monitors,
+ fun_monitors=fun_monitors,
+ jit=jit,
+ progress_bar=progress_bar,
+ dyn_vars=dyn_vars,
+ numpy_mon_after_run=numpy_mon_after_run)
# t0 and i0
self.i0 = 0
- self._t0 = t0
self.t0 = t0
if data_first_axis is None:
data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
@@ -369,7 +360,7 @@ def __init__(
self._inputs = check_and_format_inputs(host=target, inputs=inputs)
# run function
- self._f_predict_compiled = dict()
+ self._jit_step_func_predict = bm.jit(self._step_func_predict, static_argnames=['shared_args'])
# monitors
self._memory_efficient = memory_efficient
@@ -388,15 +379,15 @@ def __repr__(self):
def reset_state(self):
"""Reset state of the ``DSRunner``."""
self.i0 = 0
- self.t0 = self._t0
+ self.t0 = self.t0
def predict(
self,
duration: float = None,
- inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]] = None,
+ inputs: Any = None,
reset_state: bool = False,
- shared_args: Dict = None,
eval_time: bool = False,
+ shared_args: Dict = None,
# deprecated
inputs_are_batching: bool = None,
@@ -431,10 +422,10 @@ def predict(
Will be removed after version 2.4.0.
reset_state: bool
Whether reset the model states.
- shared_args: optional, dict
- The shared arguments across different layers.
eval_time: bool
Whether ro evaluate the running time.
+ shared_args: optional, dict
+ The shared arguments across different layers.
Returns
-------
@@ -469,13 +460,7 @@ def predict(
self.reset_state()
# shared arguments and inputs
- if shared_args is None:
- shared_args = dict()
- shared_args['fit'] = shared_args.get('fit', False)
- shared = tools.DotDict(i=np.arange(num_step, dtype=bm.int_))
- shared['t'] = shared['i'] * self.dt
- shared['i'] += self.i0
- shared['t'] += self.t0
+ indices = np.arange(self.i0, self.i0 + num_step, dtype=bm.int_)
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
inputs = tree_map(lambda x: jnp.moveaxis(x, 0, 1), inputs)
@@ -492,8 +477,11 @@ def predict(
# running
if eval_time:
t0 = time.time()
- with jax.disable_jit(not self.jit['predict']):
- outputs, hists = self._predict(xs=(shared['t'], shared['i'], inputs), shared_args=shared_args)
+ if inputs is None:
+ inputs = tuple()
+ if not isinstance(inputs, (tuple, list)):
+ inputs = (inputs,)
+ outputs, hists = self._predict(indices, *inputs, shared_args=shared_args)
if eval_time:
running_time = time.time() - t0
@@ -503,17 +491,18 @@ def predict(
# post-running for monitors
if self._memory_efficient:
- self.mon['ts'] = shared['t'] + self.dt
+ self.mon['ts'] = indices * self.dt + self.t0
for key in self.mon.var_names:
self.mon[key] = np.asarray(self.mon[key])
else:
- hists['ts'] = shared['t'] + self.dt
+ hists['ts'] = indices * self.dt + self.t0
if self.numpy_mon_after_run:
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
+ else:
+ hists['ts'] = bm.as_jax(hists['ts'])
for key in hists.keys():
self.mon[key] = hists[key]
self.i0 += num_step
- self.t0 += (num_step * self.dt if duration is None else duration)
return outputs if not eval_time else (running_time, outputs)
def run(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]:
@@ -526,17 +515,12 @@ def __call__(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]:
"""
return self.predict(*args, **kwargs)
- def _predict(
- self,
- xs: Sequence,
- shared_args: Dict = None,
- ) -> Union[Output, Monitor]:
+ def _predict(self, indices, *xs, shared_args=None) -> Union[Output, Monitor]:
"""Predict the output according to the inputs.
Parameters
----------
xs: sequence
- Must be a tuple/list of data, including `(times, indices, inputs)`.
If `inputs` is not None, it should be a tensor with the shape of
:math:`(num_time, ...)`.
shared_args: optional, dict
@@ -547,18 +531,21 @@ def _predict(
outputs, hists
A tuple of pair of (outputs, hists).
"""
- _predict_func = self._get_f_predict(shared_args)
- outs_and_mons = _predict_func(xs)
+ if shared_args is None:
+ shared_args = dict()
+ shared_args = tools.DotDict(shared_args)
+
+ outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args)
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
outs_and_mons = tree_map(lambda x: jnp.moveaxis(x, 0, 1) if x.ndim >= 2 else x,
outs_and_mons)
return outs_and_mons
- def _step_func_monitor(self, shared):
+ def _step_func_monitor(self):
res = dict()
for key, val in self._monitors.items():
if callable(val):
- res[key] = val(shared)
+ res[key] = _call_fun_with_share(val)
else:
(variable, idx) = val
if idx is None:
@@ -567,21 +554,21 @@ def _step_func_monitor(self, shared):
res[key] = variable[bm.as_jax(idx)]
return res
- def _step_func_input(self, shared):
+ def _step_func_input(self):
if self._fun_inputs is not None:
- self._fun_inputs(shared)
+ self._fun_inputs(share.get_shargs())
if callable(self._inputs):
- self._inputs(shared)
+ _call_fun_with_share(self._inputs)
else:
for ops, values in self._inputs['fixed'].items():
for var, data in values:
_f_ops(ops, var, data)
for ops, values in self._inputs['array'].items():
for var, data in values:
- _f_ops(ops, var, data[shared['i']])
+ _f_ops(ops, var, data[share['i']])
for ops, values in self._inputs['functional'].items():
for var, data in values:
- _f_ops(ops, var, data(shared))
+ _f_ops(ops, var, _call_fun_with_share(data))
for ops, values in self._inputs['iterated'].items():
for var, data in values:
_f_ops(ops, var, next(data))
@@ -628,25 +615,24 @@ def _step_mon_on_cpu(self, args, transforms):
for key, val in args.items():
self.mon[key].append(val)
- def _step_func_predict(self, shared_args, t, i, x):
+ def _step_func_predict(self, i, *x, shared_args=None):
# input step
- shared = tools.DotDict(t=t, i=i, dt=self.dt)
- shared.update(shared_args)
- share.save(**shared)
- self._step_func_input(shared)
+ if shared_args is not None:
+ assert isinstance(shared_args, dict)
+ share.save(**shared_args)
+ share.save(t=self.t0 + i * self.dt, i=i, dt=self.dt)
+ self._step_func_input()
# dynamics update step
- args = () if x is None else (x,)
- out = self.target(*args)
+ out = self.target(*x)
# monitor step
- shared['t'] += self.dt
- mon = self._step_func_monitor(shared)
+ mon = self._step_func_monitor()
# finally
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
- share.clear_shargs()
+ # share.clear_shargs()
self.target.clear_input()
if self._memory_efficient:
@@ -655,40 +641,23 @@ def _step_func_predict(self, shared_args, t, i, x):
else:
return out, mon
- def _get_f_predict(self, shared_args: Dict = None):
- if shared_args is None:
- shared_args = dict()
-
- shared_kwargs_str = serialize_kwargs(shared_args)
- if shared_kwargs_str not in self._f_predict_compiled:
-
- if self._memory_efficient:
- _jit_step = bm.jit(partial(self._step_func_predict, shared_args))
-
- def run_func(all_inputs):
- outs = None
- times, indices, xs = all_inputs
- for i in range(times.shape[0]):
- out, _ = _jit_step(times[i], indices[i], tree_map(lambda a: a[i], xs))
- if outs is None:
- outs = tree_map(lambda a: [], out)
- outs = tree_map(lambda a, o: o.append(a), out, outs)
- outs = tree_map(lambda a: bm.as_jax(a), outs)
- return outs, None
-
+ def _fun_predict(self, indices, *inputs, shared_args=None):
+ if self._memory_efficient:
+ if self.jit['predict']:
+ run_fun = self._jit_step_func_predict
else:
- step = partial(self._step_func_predict, shared_args)
+ run_fun = self._step_func_predict
- def run_func(all_inputs):
- return bm.for_loop(step, all_inputs, jit=self.jit['predict'])
-
- self._f_predict_compiled[shared_kwargs_str] = run_func
-
- return self._f_predict_compiled[shared_kwargs_str]
-
- def __del__(self):
- if hasattr(self, '_f_predict_compiled'):
- for key in tuple(self._f_predict_compiled.keys()):
- self._f_predict_compiled.pop(key)
- super(DSRunner, self).__del__()
+ outs = None
+ for i in range(indices.shape[0]):
+ out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs), shared_args=shared_args)
+ if outs is None:
+ outs = tree_map(lambda a: [], out)
+ outs = tree_map(lambda a, o: o.append(a), out, outs)
+ outs = tree_map(lambda a: bm.as_jax(a), outs)
+ return outs, None
+ else:
+ return bm.for_loop(functools.partial(self._step_func_predict, shared_args=shared_args),
+ (indices, *inputs),
+ jit=self.jit['predict'])
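Note on the refactor above: the per-`shared_args` compiled-function cache is gone; `_step_func_predict` is jitted once and driven by a plain index array, with monitor times rebuilt as `indices * self.dt + self.t0`. A minimal usage sketch of the refactored runner, modeled on the COBA examples later in this series (network size and monitor key are illustrative):

    import brainpy as bp

    net = bp.dyn.LifRef(10, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                        V_initializer=bp.init.Normal(-55., 2.))
    runner = bp.DSRunner(net, monitors={'V': net.V})
    # `run` forwards to `predict`; with eval_time=True it returns (wall time, outputs)
    print(runner.run(100., eval_time=True))
    print(runner.mon['ts'].shape, runner.mon['V'].shape)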
diff --git a/brainpy/_src/running/runner.py b/brainpy/_src/running/runner.py
index 4245cb2d7..2a2de3d3f 100644
--- a/brainpy/_src/running/runner.py
+++ b/brainpy/_src/running/runner.py
@@ -127,7 +127,6 @@ def __init__(
# dynamical changed variables
self._dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
- self.register_implicit_vars(self._dyn_vars)
# numpy mon after run
self.numpy_mon_after_run = numpy_mon_after_run
diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py
index 38ac3f848..806b68693 100644
--- a/brainpy/_src/train/back_propagation.py
+++ b/brainpy/_src/train/back_propagation.py
@@ -2,7 +2,6 @@
import time
from collections.abc import Iterable
-from functools import partial
from typing import Union, Dict, Callable, Sequence, Optional
import jax.numpy as jnp
@@ -10,14 +9,13 @@
from jax.tree_util import tree_map
from tqdm import tqdm
+from brainpy import tools
import brainpy.losses as losses
import brainpy.math as bm
-from brainpy import tools, optim
-from brainpy._src.dynsys import DynamicalSystem
+from brainpy import optim
from brainpy._src.context import share
-from brainpy._src.math.object_transform.base import BrainPyObject
+from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.running import constants as c
-from brainpy.check import serialize_kwargs
from brainpy.errors import UnsupportedError, NoLongerSupportError
from brainpy.types import ArrayType, Output
from ._utils import msg
@@ -83,8 +81,7 @@ def __init__(
**kwargs,
):
- super(BPTrainer, self).__init__(target=target,
- **kwargs)
+ super().__init__(target=target, **kwargs)
if shuffle_data is not None:
raise NoLongerSupportError(
@@ -137,8 +134,9 @@ def __init__(
self._detailed_test_metrics = dict()
# functions
- self._f_loss_compiled = dict()
- self._f_grad_compiled = dict()
+ self._jit_step_func_grad = bm.jit(self._step_func_grad, static_argnums=(0,))
+ self._jit_step_func_loss = bm.jit(self._step_func_loss, static_argnums=(0,))
+ self._jit_step_func_fit = bm.jit(self._step_func_fit, static_argnums=(0,))
def __repr__(self):
name = self.__class__.__name__
@@ -230,6 +228,11 @@ def fit(
Please set batch size in your dataset.
"""
+ if shared_args is None:
+ shared_args = dict()
+ shared_args['fit'] = shared_args.get('fit', True)
+ shared_args = tools.DotDict(shared_args)
+
if batch_size is not None:
raise NoLongerSupportError('Please set batch size in your data. '
'Specifically, make an iterable dataset '
@@ -246,7 +249,7 @@ def fit(
if shared_args is None:
shared_args = dict()
- shared_args['fit'] = shared_args.get('fit', False)
+ shared_args['fit'] = shared_args.get('fit', True)
true_progress_bar = self.progress_bar
self.progress_bar = False
@@ -277,7 +280,7 @@ def fit(
self.reset_state()
# training
- res = self._get_f_train(shared_args)(x, y)
+ res = self.f_train(shared_args, x, y)
# loss
fit_epoch_metric['loss'].append(res[0])
@@ -355,7 +358,7 @@ def fit(
self.reset_state()
# testing
- res = self._get_f_loss(shared_args)(x, y)
+ res = self.f_loss(shared_args, x, y)
# loss
if self.loss_has_aux:
@@ -426,61 +429,32 @@ def fit(
self._detailed_test_metrics = {k: np.asarray(v) for k, v in detailed_test_metric.items()}
self.progress_bar = true_progress_bar
- def _get_f_loss(self, shared_args=None, jit=True) -> Callable:
- """Get loss function."""
- if shared_args is None:
- shared_args = dict()
- shared_args2 = {k: v for k, v in shared_args.items()}
- shared_args2['_local_jit_'] = jit
- shared_args_str = serialize_kwargs(shared_args2)
- if shared_args_str not in self._f_loss_compiled:
- self._f_loss_compiled[shared_args_str] = partial(self._step_func_loss, shared_args)
- if self.jit[c.LOSS_PHASE] and jit:
- self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str])
- return self._f_loss_compiled[shared_args_str]
-
- def _get_f_grad(self, shared_args=None) -> Callable:
- """Get gradient function."""
- shared_args_str = serialize_kwargs(shared_args)
- if shared_args_str not in self._f_grad_compiled:
- _f_loss_internal = self._get_f_loss(shared_args, jit=False)
- dyn_vars = self.target.vars()
- dyn_vars.update(self._dyn_vars)
- tran_vars = dyn_vars.subset(bm.TrainVar).unique()
- grad_f = bm.grad(_f_loss_internal,
- grad_vars=tran_vars,
- return_value=True,
- has_aux=self.loss_has_aux)
- self._f_grad_compiled[shared_args_str] = grad_f
- return self._f_grad_compiled[shared_args_str]
-
- def _get_f_train(self, shared_args=None) -> Callable:
- """Get training function."""
- if shared_args is None: shared_args = dict()
- if not isinstance(shared_args, dict):
- raise ValueError(f'Only supports dict for "shared_args". '
- f'But got {type(shared_args)}: {shared_args}')
-
- shared_args_str = serialize_kwargs(shared_args)
- if shared_args_str not in self._f_fit_compiled:
- self._f_fit_compiled[shared_args_str] = partial(self._step_func_fit, shared_args)
- if self.jit[c.FIT_PHASE]:
- dyn_vars = self.target.vars()
- dyn_vars.update(self.optimizer.vars())
- if isinstance(self._loss_func, BrainPyObject):
- dyn_vars.update(self._loss_func)
- dyn_vars.update(self._dyn_vars)
- dyn_vars.update(self.vars(level=0))
- dyn_vars = dyn_vars.unique()
- self._f_fit_compiled[shared_args_str] = bm.jit(self._f_fit_compiled[shared_args_str])
- return self._f_fit_compiled[shared_args_str]
+ def _step_func_grad(self, shared_args, inputs, targets):
+ tran_vars = self.target.train_vars().unique()
+ grad_f = bm.grad(self._step_func_loss,
+ grad_vars=tran_vars,
+ return_value=True,
+ has_aux=self.loss_has_aux)
+ return grad_f(shared_args, inputs, targets)
def _step_func_loss(self, shared_args, inputs, targets):
raise NotImplementedError
+ @property
+ def f_loss(self):
+ return self._jit_step_func_loss if self.jit[c.LOSS_PHASE] else self._step_func_loss
+
def _step_func_fit(self, shared_args, inputs, targets):
raise NotImplementedError
+ @property
+ def f_train(self):
+ return self._jit_step_func_fit if self.jit[c.FIT_PHASE] else self._step_func_fit
+
+ @property
+ def f_grad(self):
+ return self._jit_step_func_grad if self.jit[c.FIT_PHASE] else self._step_func_grad
+
class BPTT(BPTrainer):
"""The trainer implementing the back-propagation through time (BPTT)
@@ -528,18 +502,17 @@ def loss_fun(predicts, targets):
def _step_func_loss(self, shared_args, inputs, targets):
num_step = self._get_input_time_step(xs=inputs)
- indices = jnp.arange(num_step, dtype=bm.int_)
- times = indices * self.dt + self.t0
- indices = indices + self.i0
+ indices = np.arange(self.i0, self.i0 + num_step, dtype=np.int_)
if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), inputs, is_leaf=lambda x: isinstance(x, bm.Array))
- inputs = (times, indices, inputs)
- outs, mons = self._predict(xs=inputs, shared_args=shared_args)
+ if not isinstance(inputs, (tuple, list)):
+ inputs = (inputs,)
+ outs, mons = self._predict(indices, *inputs, shared_args=shared_args)
predicts = (outs, mons) if len(mons) > 0 else outs
return self._loss_func(predicts, targets)
def _step_func_fit(self, shared_args, inputs, targets):
- res = self._get_f_grad(shared_args)(inputs, targets)
+ res = self.f_grad(shared_args, inputs, targets)
self.optimizer.update(res[0])
return res[1:]
@@ -554,49 +527,43 @@ class BPFF(BPTrainer):
"""
def _step_func_loss(self, shared_args, inputs, targets):
- outputs, mon = self._get_f_predict(shared_args)(inputs)
+ if not isinstance(inputs, (tuple, list)):
+ inputs = (inputs,)
+ outputs, mon = self._step_func_predict(*inputs, shared_args=shared_args)
outs = (outputs, mon) if len(mon) > 0 else outputs
loss = self._loss_func(outs, targets)
return loss
def _step_func_fit(self, shared_args, inputs, targets):
- res = self._get_f_grad(shared_args)(inputs, targets)
+ res = self.f_grad(shared_args, inputs, targets)
self.optimizer.update(res[0])
return res[1:]
- def _step_func_predict(self, shared, x=None):
- assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.'
- for k, v in shared.items():
- share.save(k, v)
+ def _step_func_predict(self, *x, shared_args=None):
+ assert self.data_first_axis == 'B', (f'There is no time dimension when '
+ f'using the trainer {self.__class__.__name__}.')
+ if shared_args is not None:
+ assert isinstance(shared_args, dict)
+ share.save(**shared_args)
+ share.save(dt=self.dt)
# input step
self.target.clear_input()
- self._step_func_input(shared)
+ self._step_func_input()
# dynamics update step
- args = () if x is None else (x, )
- out = self.target(*args)
+ out = self.target(*x)
# monitor step
- mon = self._step_func_monitor(shared)
- share.clear_shargs()
+ mon = self._step_func_monitor()
+ # share.clear_shargs()
return out, mon
- def _get_f_predict(self, shared_args: Dict = None, jit: bool = True):
- if shared_args is None:
- shared_args = tools.DotDict()
- if not isinstance(shared_args, dict):
- raise ValueError(f'"shared_args" must be a dict, but got {type(shared_args)}')
-
- shared_args2 = {k: v for k, v in shared_args.items()}
- shared_args2['_local_jit_'] = jit
- shared_args_str = serialize_kwargs(shared_args)
- if shared_args_str not in self._f_predict_compiled:
-
- self._f_predict_compiled[shared_args_str] = partial(self._step_func_predict, shared_args)
- if self.jit[c.PREDICT_PHASE] and jit:
- self._f_predict_compiled[shared_args_str] = bm.jit(self._f_predict_compiled[shared_args_str])
- return self._f_predict_compiled[shared_args_str]
+ def _fun_predict(self, *inputs, shared_args=None):
+ if self.jit['predict']:
+ return self._jit_step_func_predict(*inputs, shared_args=shared_args)
+ else:
+ return self._step_func_predict(*inputs, shared_args=shared_args)
def predict(
self,
@@ -628,8 +595,10 @@ def predict(
output: ArrayType, dict
The model output.
"""
- if shared_args is None: shared_args = dict()
+ if shared_args is None:
+ shared_args = dict()
shared_args['fit'] = shared_args.get('fit', False)
+ shared_args = tools.DotDict(shared_args)
# reset the model states
if reset_state:
@@ -639,8 +608,10 @@ def predict(
for key in self.mon.var_names:
self.mon[key] = [] # reshape the monitor items
# prediction
+ if not isinstance(inputs, (tuple, list)):
+ inputs = (inputs,)
if eval_time: t0 = time.time()
- outs, hists = self._predict(xs=inputs, shared_args=shared_args)
+ outs, hists = self._fun_predict(*inputs, shared_args=shared_args)
if eval_time: t1 = time.time()
# post-running for monitors
for key in hists.keys():
@@ -649,5 +620,3 @@ def predict(
for key in hists.keys():
self.mon[key] = np.asarray(self.mon[key])
return (t1 - t0, outs) if eval_time else outs
-
-
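Note: the `_get_f_loss`/`_get_f_grad`/`_get_f_train` caches are replaced by `f_loss`/`f_grad`/`f_train` properties that simply pick the jitted or plain step function, with `shared_args` passed as a static argument. A condensed sketch of that pattern (class and method names shortened for illustration, not the actual trainer API; the hashability of `DotDict` that it relies on is added a couple of patches below):

    import brainpy.math as bm
    from brainpy import tools

    class Sketch:
        def __init__(self, jit_fit=True):
            self.jit = {'fit': jit_fit}
            # shared_args is static, so it must be hashable
            self._jit_step = bm.jit(self._step, static_argnums=(0,))

        def _step(self, shared_args, x, y):
            return bm.mean((x - y) ** 2)   # placeholder loss

        @property
        def f_train(self):
            return self._jit_step if self.jit['fit'] else self._step

    s = Sketch()
    loss = s.f_train(tools.DotDict(fit=True), bm.ones(3), bm.zeros(3))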
diff --git a/brainpy/_src/train/base.py b/brainpy/_src/train/base.py
index eb19d24d1..97e20a384 100644
--- a/brainpy/_src/train/base.py
+++ b/brainpy/_src/train/base.py
@@ -40,7 +40,7 @@ def __init__(
target: DynamicalSystem,
**kwargs
):
- super(DSTrainer, self).__init__(target=target, **kwargs)
+ super().__init__(target=target, **kwargs)
if not isinstance(self.target.mode, bm.BatchingMode):
raise NoLongerSupportError(f'''
@@ -59,12 +59,9 @@ def __init__(
self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True)
self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True)
- # training function
- self._f_fit_compiled = dict()
-
def predict(
self,
- inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]],
+ inputs: Any,
reset_state: bool = False,
shared_args: Optional[Dict] = None,
eval_time: bool = False
@@ -77,10 +74,10 @@ def predict(
The input values.
reset_state: bool
Reset the target state before running.
- shared_args: dict
- The shared arguments across nodes.
eval_time: bool
Whether to evaluate the running time.
+ shared_args: dict
+ The shared arguments across nodes.
Returns
-------
@@ -90,7 +87,6 @@ def predict(
if shared_args is None:
shared_args = dict()
shared_args['fit'] = shared_args.get('fit', False)
-
return super().predict(inputs=inputs,
reset_state=reset_state,
shared_args=shared_args,
diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py
index fc4a4efd8..ab1521a36 100644
--- a/brainpy/_src/train/offline.py
+++ b/brainpy/_src/train/offline.py
@@ -1,15 +1,17 @@
# -*- coding: utf-8 -*-
-from typing import Dict, Sequence, Union, Callable
+from typing import Dict, Sequence, Union, Callable, Any
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
import brainpy.math as bm
+from brainpy import tools
+from brainpy._src.context import share
from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.runners import _call_fun_with_share
from brainpy.algorithms.offline import get, RidgeRegression, OfflineAlgorithm
-from brainpy.check import serialize_kwargs
from brainpy.errors import NoImplementationError
from brainpy.types import ArrayType, Output
from ._utils import format_ys
@@ -56,7 +58,7 @@ def __init__(
):
self._true_numpy_mon_after_run = kwargs.get('numpy_mon_after_run', True)
kwargs['numpy_mon_after_run'] = False
- super(OfflineTrainer, self).__init__(target=target, **kwargs)
+ super().__init__(target=target, **kwargs)
# get all trainable nodes
nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique()
@@ -83,6 +85,8 @@ def __init__(
# set the training method
for node in self.train_nodes:
node.offline_fit_by = fit_method
+ # training function
+ self._jit_fun_train = bm.jit(self._fun_train, static_argnames=['shared_args'])
def __repr__(self):
name = self.__class__.__name__
@@ -92,7 +96,7 @@ def __repr__(self):
def predict(
self,
- inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]],
+ inputs: Any,
reset_state: bool = False,
shared_args: Dict = None,
eval_time: bool = False
@@ -108,20 +112,18 @@ def predict(
The input values.
reset_state: bool
Reset the target state before running.
- shared_args: dict
- The shared arguments across nodes.
eval_time: bool
Whether to evaluate the running time.
+ shared_args: dict
+ The shared arguments across nodes.
Returns
-------
output: ArrayType
The running output.
"""
- outs = super(OfflineTrainer, self).predict(inputs=inputs,
- reset_state=reset_state,
- shared_args=shared_args,
- eval_time=eval_time)
+ outs = super().predict(inputs=inputs, reset_state=reset_state,
+ eval_time=eval_time, shared_args=shared_args)
for node in self.train_nodes:
node.fit_record.clear()
return outs
@@ -152,8 +154,10 @@ def fit(
shared_args: dict
The shared keyword arguments for the target models.
"""
- if shared_args is None: shared_args = dict()
+ if shared_args is None:
+ shared_args = dict()
shared_args['fit'] = shared_args.get('fit', True)
+ shared_args = tools.DotDict(shared_args)
# checking training and testing data
if not isinstance(train_data, (list, tuple)):
@@ -167,6 +171,7 @@ def fit(
xs, ys = train_data
# prediction, get all needed data
+ shared_args['fit'] = shared_args.get('fit', False)
outs = self.predict(inputs=xs, reset_state=reset_state, shared_args=shared_args)
# check target data
@@ -182,7 +187,9 @@ def fit(
for node in self.train_nodes:
key = f'{node.name}-fit_record'
monitor_data[key] = self.mon.get(key)
- self._get_f_train(shared_args)(monitor_data, ys)
+ run_fun = self._jit_fun_train if self.jit['fit'] else self._fun_train
+ shared_args['fit'] = True
+ run_fun(monitor_data, ys, shared_args=shared_args)
del monitor_data
# close the progress bar
@@ -199,19 +206,14 @@ def fit(
return outs
- def _get_f_train(self, shared_args: Dict = None) -> Callable:
- """Get training function."""
- shared_args = dict() if shared_args is None else shared_args
- shared_kwargs_str = serialize_kwargs(shared_args)
- if shared_kwargs_str not in self._f_fit_compiled:
- self._f_fit_compiled[shared_kwargs_str] = (
- self._fun_train
- if self.jit['fit'] else
- bm.jit(self._fun_train)
- )
- return self._f_fit_compiled[shared_kwargs_str]
-
- def _fun_train(self, monitor_data: Dict[str, ArrayType], target_data: Dict[str, ArrayType]):
+ def _fun_train(self,
+ monitor_data: Dict[str, ArrayType],
+ target_data: Dict[str, ArrayType],
+ shared_args: Dict = None):
+ if shared_args is None:
+ shared_args = dict()
+ share.save(**shared_args)
+
for node in self.train_nodes:
fit_record = monitor_data[f'{node.name}-fit_record']
targets = target_data[node.name]
@@ -219,18 +221,18 @@ def _fun_train(self, monitor_data: Dict[str, ArrayType], target_data: Dict[str,
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())
- def _step_func_monitor(self, shared):
+ def _step_func_monitor(self):
res = dict()
for key, val in self._monitors.items():
if callable(val):
- res[key] = val(shared)
+ res[key] = _call_fun_with_share(val)
else:
(variable, idx) = val
if idx is None:
res[key] = variable.value
else:
res[key] = variable[bm.asarray(idx)]
- if shared.get('fit', False):
+ if share.load('fit'):
for node in self.train_nodes:
res[f'{node.name}-fit_record'] = node.fit_record
return res
@@ -238,8 +240,8 @@ def _step_func_monitor(self, shared):
def _check_interface(self):
for node in self.train_nodes:
if not hasattr(node, 'offline_fit'):
- raise NoImplementationError(
- f'''
+ raise NoImplementationError(
+ f'''
The node
{node}
@@ -248,20 +250,7 @@ def _check_interface(self):
However, it does not implement the required training
interface "offline_fit()" function.
'''
- )
- # if hasattr(node.offline_init, 'not_customized'):
- # if node.offline_init.not_customized:
- # raise NoImplementationError(
- # f'''
- # The node
- #
- # {node}
- #
- # is set to be computing mode of {bm.training_mode} with {self.__class__.__name__}.
- # However, it does not implement the required training
- # interface "offline_init()" function.
- # '''
- # )
+ )
class RidgeTrainer(OfflineTrainer):
@@ -278,6 +267,4 @@ class RidgeTrainer(OfflineTrainer):
"""
def __init__(self, target, alpha=1e-7, **kwargs):
- super(RidgeTrainer, self).__init__(target=target,
- fit_method=dict(name='ridge', alpha=alpha),
- **kwargs)
+ super().__init__(target=target, fit_method=dict(name='ridge', alpha=alpha), **kwargs)
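Note: `OfflineTrainer._fun_train` now receives `shared_args` and writes it into the global `share` before fitting, and the jitted variant is built once in `__init__`. For orientation, a hedged end-to-end sketch of ridge training with the ESN used in the tests of this series (the public alias `bp.RidgeTrainer` and the `(batch, time, feature)` data layout are assumptions here):

    import numpy as np
    import brainpy as bp
    import brainpy.math as bm

    class ESN(bp.DynamicalSystem):
        def __init__(self, num_in, num_hidden, num_out):
            super().__init__()
            self.r = bp.dnn.Reservoir(num_in, num_hidden)
            self.o = bp.dnn.Dense(num_hidden, num_out, mode=bm.training_mode)

        def update(self, x):
            return x >> self.r >> self.o

    with bm.environment(mode=bm.training_mode):
        model = ESN(1, 100, 1)
    trainer = bp.RidgeTrainer(model, alpha=1e-6)   # import path assumed
    X = np.random.rand(1, 200, 1)                  # (batch, time, feature)
    Y = np.random.rand(1, 200, 1)
    trainer.fit([X, Y])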
diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py
index 837e1df08..08214e7d7 100644
--- a/brainpy/_src/train/online.py
+++ b/brainpy/_src/train/online.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
-
-from functools import partial
-from typing import Dict, Sequence, Union, Callable, Tuple
+import functools
+from typing import Dict, Sequence, Union, Callable
import numpy as np
import tqdm.auto
@@ -9,10 +8,10 @@
from jax.tree_util import tree_map
from brainpy import math as bm, tools
-from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.context import share
+from brainpy._src.dynsys import DynamicalSystem
+from brainpy._src.runners import _call_fun_with_share
from brainpy.algorithms.online import get, OnlineAlgorithm, RLS
-from brainpy.check import serialize_kwargs
from brainpy.errors import NoImplementationError
from brainpy.types import ArrayType, Output
from ._utils import format_ys
@@ -58,7 +57,7 @@ def __init__(
fit_method: Union[OnlineAlgorithm, Callable, Dict, str] = None,
**kwargs
):
- super(OnlineTrainer, self).__init__(target=target, **kwargs)
+ super().__init__(target=target, **kwargs)
# get all trainable nodes
nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique()
@@ -145,6 +144,7 @@ def fit(
) -> Output:
if shared_args is None: shared_args = dict()
shared_args['fit'] = shared_args.get('fit', True)
+ shared_args = tools.DotDict(shared_args)
# checking training and testing data
if not isinstance(train_data, (list, tuple)):
@@ -166,11 +166,8 @@ def fit(
# format input/target data
ys = format_ys(self, ys)
num_step = self._get_input_time_step(xs=xs)
- shared = tools.DotDict(i=bm.arange(num_step, dtype=bm.int_).value)
- shared['t'] = shared['i'] * self.dt
- shared['t'] += self.t0
- shared['i'] += self.i0
+ indices = np.arange(self.i0, num_step + self.i0, dtype=np.int_)
if self.data_first_axis == 'B':
xs = tree_map(lambda x: bm.moveaxis(x, 0, 1),
xs,
@@ -189,35 +186,33 @@ def fit(
self._pbar.set_description(f"Train {num_step} steps: ", refresh=True)
# prediction
- outs, hists = self._fit(tix=(shared['t'], shared['i'], xs), ys=ys, shared_args=shared_args)
+ xs = (xs, ) if not isinstance(xs, (tuple, list)) else xs
+ outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args)
# close the progress bar
if self.progress_bar:
self._pbar.close()
# post-running for monitors
- hists['ts'] = shared['t'] + self.dt
if self.numpy_mon_after_run:
hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
for key in hists.keys():
self.mon[key] = hists[key]
- self.i0 += shared['t'].shape[0]
- self.t0 += num_step * self.dt
+ self.i0 += num_step
return outs
- def _fit(
- self,
- tix: Tuple,
- ys: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]],
- shared_args: Dict = None,
- ):
+ def _fit(self,
+ indices: ArrayType,
+ xs: Sequence,
+ ys: Dict[str, ArrayType],
+ shared_args: Dict = None):
"""Predict the output according to the inputs.
Parameters
----------
- tix: tuple
- Each tensor should have the shape of `(num_time, num_batch, num_feature)`.
- ys: ArrayType
+ indices: ArrayType
+ The running indices.
+ ys: dict
Each tensor should have the shape of `(num_time, num_batch, num_feature)`.
shared_args: optional, dict
The shared keyword arguments.
@@ -227,41 +222,28 @@ def _fit(
outputs, hists
A tuple of pair of (outputs, hists).
"""
- _fit_func = self._get_fit_func(shared_args)
- hists = _fit_func(tix + (ys,))
+ hists = bm.for_loop(functools.partial(self._step_func_fit, shared_args=shared_args),
+ (indices, xs, ys),
+ jit=self.jit['fit'])
hists = tree_map(lambda x: bm.moveaxis(x, 0, 1),
hists,
is_leaf=lambda x: isinstance(x, bm.Array))
return hists
- def _get_fit_func(self, shared_args: Dict = None):
- if shared_args is None: shared_args = dict()
- shared_kwargs_str = serialize_kwargs(shared_args)
- if shared_kwargs_str not in self._f_fit_compiled:
- @bm.jit
- def run_func(all_inputs):
- return bm.for_loop(partial(self._step_func_fit, shared_args),
- all_inputs,
- jit=self.jit['fit'])
-
- self._f_fit_compiled[shared_kwargs_str] = run_func
- return self._f_fit_compiled[shared_kwargs_str]
-
- def _step_func_fit(self, shared_args, t, i, x, ys):
- shared = tools.DotDict(t=t, dt=self.dt, i=i)
- shared.update(shared_args)
- share.save(**shared)
+ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
+ if shared_args is None:
+ shared_args = dict()
+ share.save(t=i * self.dt, dt=self.dt, i=i, **shared_args)
# input step
self.target.clear_input()
- self._step_func_input(shared)
+ self._step_func_input()
# update step
- args = () if x is None else (x, )
- out = self.target(*args)
+ out = self.target(*xs)
# monitor step
- monitors = self._step_func_monitor(shared)
+ monitors = self._step_func_monitor()
for node in self.train_nodes:
fit_record = monitors.pop(f'{node.name}-fit_record')
target = ys[node.name]
@@ -275,32 +257,32 @@ def _step_func_fit(self, shared_args, t, i, x, ys):
def _check_interface(self):
for node in self.train_nodes:
if not hasattr(node, 'online_fit'):
- raise NoImplementationError(
- f'The node \n\n{node}\n\n'
- f'is set to be trainable with {self.__class__.__name__} method. '
- f'However, it does not implement the required training '
- f'interface "online_fit()" function. '
- )
+ raise NoImplementationError(
+ f'The node \n\n{node}\n\n'
+ f'is set to be trainable with {self.__class__.__name__} method. '
+ f'However, it does not implement the required training '
+ f'interface "online_fit()" function. '
+ )
if not hasattr(node, 'online_init'):
- raise NoImplementationError(
- f'The node \n\n{node}\n\n'
- f'is set to be trainable with {self.__class__.__name__} method. '
- f'However, it does not implement the required training '
- f'interface "online_init()" function. '
- )
-
- def _step_func_monitor(self, shared):
+ raise NoImplementationError(
+ f'The node \n\n{node}\n\n'
+ f'is set to be trainable with {self.__class__.__name__} method. '
+ f'However, it does not implement the required training '
+ f'interface "online_init()" function. '
+ )
+
+ def _step_func_monitor(self):
res = dict()
for key, val in self._monitors.items():
if callable(val):
- res[key] = val(shared)
+ res[key] = _call_fun_with_share(val)
else:
(variable, idx) = val
if idx is None:
res[key] = variable.value
else:
res[key] = variable[bm.asarray(idx)]
- if shared.get('fit', False):
+ if share.load('fit'):
for node in self.train_nodes:
res[f'{node.name}-fit_record'] = node.fit_record
return res
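Note: across these runner and trainer changes, user-supplied monitor and input callables are invoked through `_call_fun_with_share`, and timing lives in the global context, so the callables no longer take a `shared` dict. A sketch of a monitor written in the new style (the monitor key and target model are illustrative):

    import brainpy as bp

    def current_time():
        # the runner saves 't', 'i' and 'dt' into the global share before each step
        return bp.share['t']

    # used as: bp.DSRunner(net, monitors={'time': current_time})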
From 520d828ef9b4f49714a1b9b1817f7fb6199f0c64 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 10:22:15 +0800
Subject: [PATCH 057/326] hashable `brainpy.tools.DotDict`
---
brainpy/_src/tools/dicts.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/brainpy/_src/tools/dicts.py b/brainpy/_src/tools/dicts.py
index 75013b82b..97b869372 100644
--- a/brainpy/_src/tools/dicts.py
+++ b/brainpy/_src/tools/dicts.py
@@ -217,6 +217,9 @@ def unique(self):
gather[k] = v
return gather
+ def __hash__(self):
+ return hash(tuple(sorted(self.items())))
+
register_pytree_node(
DotDict,
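Making `DotDict` hashable lets it act as a static jit argument and as part of function-cache keys (see the `for_loop` change two patches below, which caches on `(body_fun, unroll_kwargs)`). A quick sketch of what the new `__hash__` provides, assuming all values are themselves hashable:

    import brainpy as bp

    d = bp.tools.DotDict(fit=True, dt=0.1)
    assert hash(d) == hash(bp.tools.DotDict(dt=0.1, fit=True))   # order-independent
    cache = {d: 'compiled function'}                             # usable as a dict key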
From 15ae3aea39ce14d505eec2b3d4b2afdd4efdc470 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 10:22:25 +0800
Subject: [PATCH 058/326] updates
---
brainpy/__init__.py | 8 +++--
brainpy/_src/delay.py | 6 ++--
brainpy/_src/deprecations.py | 34 ++++++++++++++++++
brainpy/_src/dyn/projections/aligns.py | 10 +++---
brainpy/_src/dynsys.py | 20 ++---------
brainpy/_src/math/ndarray.py | 2 +-
.../_src/math/object_transform/controls.py | 12 +++++--
brainpy/dyn/channels.py | 36 +++++++++----------
8 files changed, 80 insertions(+), 48 deletions(-)
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 77302e150..7bba216f5 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -65,7 +65,7 @@
Network = DynSysGroup
# delays
from brainpy._src.delay import (
- VariDelay as VariDelay,
+ VarDelay as VarDelay,
)
# building blocks
@@ -129,12 +129,16 @@
from brainpy._add_deprecations import deprecation_getattr2
__deprecations = {
+ 'Module': ('brainpy.Module', 'brainpy.DynamicalSystem', DynamicalSystem),
+ 'Channel': ('brainpy.Channel', 'brainpy.dyn.IonChannel', dyn.IonChannel),
+ 'NeuGroup': ('brainpy.NeuGroup', 'brainpy.dyn.NeuDyn', dyn.NeuDyn),
+ 'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
'Container': ('brainpy.Container', 'brainpy.DynSysGroup', DynSysGroup),
+
'optimizers': ('brainpy.optimizers', 'brainpy.optim', optim),
'TensorCollector': ('brainpy.TensorCollector', 'brainpy.ArrayCollector', ArrayCollector),
'SynSTP': ('brainpy.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP),
'SynOut': ('brainpy.SynOut', 'brainpy.synapses.SynOut', synapses.SynOut),
- 'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
'TwoEndConn': ('brainpy.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn),
'CondNeuGroup': ('brainpy.CondNeuGroup', 'brainpy.syn.CondNeuGroup', dyn.CondNeuGroup),
}
diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py
index bac40e53f..7feb6eb03 100644
--- a/brainpy/_src/delay.py
+++ b/brainpy/_src/delay.py
@@ -21,7 +21,7 @@
__all__ = [
'Delay',
- 'VariDelay',
+ 'VarDelay',
'DataDelay',
'DelayAccess',
]
@@ -432,7 +432,7 @@ def _check_target_sharding(sharding, ndim, mode: bm.Mode):
return sharding
-class VariDelay(Delay):
+class VarDelay(Delay):
"""Generate Delays for the given :py:class:`~.Variable` instance.
The data in this delay variable is arranged as::
@@ -690,7 +690,7 @@ def _init_data(self, length: int, batch_size: int = None):
self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)
-class DataDelay(VariDelay):
+class DataDelay(VarDelay):
not_desc_params = ('time', 'entries')
def __init__(
diff --git a/brainpy/_src/deprecations.py b/brainpy/_src/deprecations.py
index 1734694e9..a71739458 100644
--- a/brainpy/_src/deprecations.py
+++ b/brainpy/_src/deprecations.py
@@ -8,6 +8,40 @@
]
+_update_deprecate_msg = '''
+From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.
+
+Instead of using:
+
+  def update(self, tdi, *args, **kwargs):
+ t = tdi['t']
+ ...
+
+Please use:
+
+  def update(self, *args, **kwargs):
+ t = bp.share['t']
+ ...
+'''
+
+
+_input_deprecate_msg = '''
+From brainpy>=2.4.3, input() function no longer needs to receive a global shared argument.
+
+Instead of using:
+
+ def input(tdi):
+ ...
+
+Please use:
+
+ def input():
+ t = bp.share['t']
+ ...
+'''
+
+
+
def _deprecate(msg):
warnings.simplefilter('always', DeprecationWarning) # turn off filter
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
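For completeness, a minimal model written in the post-2.4.3 style that these messages describe (class and variable names are illustrative):

    import brainpy as bp
    import brainpy.math as bm

    class Accumulator(bp.DynamicalSystem):
        def __init__(self):
            super().__init__()
            self.total = bm.Variable(bm.zeros(1))

        def update(self, x=None):
            # no `tdi` argument: the current time and step are read from the global share,
            # e.g. available for time-dependent dynamics
            t = bp.share['t']
            self.total += 0. if x is None else x
            return self.total.value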
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index 907a144f2..15b92b0d4 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -3,7 +3,7 @@
import jax
from brainpy import math as bm, check
-from brainpy._src.delay import Delay, VariDelay, DataDelay, DelayAccess
+from brainpy._src.delay import Delay, VarDelay, DataDelay, DelayAccess
from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic, Sequential
from brainpy._src.mixin import JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
@@ -54,7 +54,7 @@ def update(self):
def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay:
if isinstance(info, bm.Variable):
- return VariDelay(info)
+ return VarDelay(info)
elif isinstance(info, ReturnInfo):
if isinstance(info.batch_or_mode, int):
shape = (info.batch_or_mode,) + tuple(info.size)
@@ -106,7 +106,7 @@ def __init__(self):
super().__init__()
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
- self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
self.syn2 = bp.dyn.Expon(size=800, tau=10.)
self.E = bp.dyn.VanillaProj(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
@@ -180,7 +180,7 @@ def __init__(self):
super().__init__()
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
- self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
syn=bp.dyn.Expon.desc(size=4000, tau=5.),
out=bp.dyn.COBA.desc(E=0.),
@@ -374,7 +374,7 @@ def __init__(self):
super().__init__()
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
- self.delay = bp.VariableDelay(self.N.spike, entries={'I': None})
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
syn=bp.dyn.Expon(size=4000, tau=5.),
out=bp.dyn.COBA(E=0.),
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index f14302040..1f8b105ca 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -13,6 +13,7 @@
from brainpy._src.mixin import AutoDelaySupp, Container, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
+from brainpy._src.deprecations import _update_deprecate_msg
share = None
@@ -29,21 +30,6 @@
SLICE_VARS = 'slice_vars'
-_update_deprecate_msg = '''
-From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument.
-
-Instead of using:
-
- def update(self, tdi, *args, **kwagrs):
- ...
-
-Please use:
-
- def update(self, *args, **kwagrs):
- t = bp.share['t']
- ...
-'''
-
def not_pass_shared(func: Callable):
"""Label the update function as the one without passing shared arguments.
@@ -305,14 +291,14 @@ def __repr__(self):
def __call__(self, *args, **kwargs):
"""The shortcut to call ``update`` methods."""
- # update ``before_updates``
+ # ``before_updates``
for model in self.before_updates.values():
model()
# update the model self
ret = self.update(*args, **kwargs)
- # update ``after_updates``
+ # ``after_updates``
for model in self.after_updates.values():
model(ret)
return ret
diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py
index 0872d0bc8..820ffc36a 100644
--- a/brainpy/_src/math/ndarray.py
+++ b/brainpy/_src/math/ndarray.py
@@ -748,7 +748,7 @@ def split(self, indices_or_sections, axis=0):
sub-arrays : list of ndarrays
A list of sub-arrays as views into `ary`.
"""
- return [_return(a) for a in self.value.split(indices_or_sections, axis=axis)]
+ return [_return(a) for a in jnp.split(self.value, indices_or_sections, axis=axis)]
def take(self, indices, axis=None, mode=None):
"""Return an array formed from the elements of a at the given indices."""
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 00fd331b1..b3ef525a6 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -722,7 +722,7 @@ def _get_for_loop_transform(
progress_bar: bool,
remat: bool,
reverse: bool,
- unroll: int
+ unroll: int,
):
def fun2scan(carry, x):
for k in dyn_vars.keys():
@@ -753,6 +753,7 @@ def for_loop(
remat: bool = False,
jit: Optional[bool] = None,
progress_bar: bool = False,
+ unroll_kwargs: Optional[Dict] = None,
# deprecated
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
@@ -845,6 +846,8 @@ def for_loop(
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
+ unroll_kwargs: dict
+    The keyword arguments that are not iterated over; they are forwarded unchanged to every call of the loop body function.
Returns
-------
@@ -855,6 +858,9 @@ def for_loop(
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)
+ if unroll_kwargs is None:
+ unroll_kwargs = dict()
+
if not isinstance(operands, (list, tuple)):
operands = (operands,)
@@ -885,7 +891,9 @@ def for_loop(
dyn_vars = VariableStack()
# TODO: cache mechanism?
- transform = _get_for_loop_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)
+ transform = _get_for_loop_transform(body_fun, dyn_vars, bar,
+ progress_bar, remat, reverse,
+ unroll)
if jit:
dyn_vals, out_vals = transform(operands)
else:
diff --git a/brainpy/dyn/channels.py b/brainpy/dyn/channels.py
index 03d8e979f..41ed7856d 100644
--- a/brainpy/dyn/channels.py
+++ b/brainpy/dyn/channels.py
@@ -1,29 +1,29 @@
from brainpy._src.dyn.channels.base import (
- IonChannel,
+ IonChannel as IonChannel,
)
from brainpy._src.dyn.channels.calcium import (
- CalciumChannel,
- ICaN_IS2008,
- ICaT_HM1992,
- ICaT_HP1992,
- ICaHT_HM1992,
- ICaHT_Re1993,
- ICaL_IS2008,
+ CalciumChannel as CalciumChannel,
+ ICaN_IS2008 as ICaN_IS2008,
+ ICaT_HM1992 as ICaT_HM1992,
+ ICaT_HP1992 as ICaT_HP1992,
+ ICaHT_HM1992 as ICaHT_HM1992,
+ ICaHT_Re1993 as ICaHT_Re1993,
+ ICaL_IS2008 as ICaL_IS2008,
)
from brainpy._src.dyn.channels.potassium import (
- PotassiumChannel,
- IKDR_Ba2002v2,
- IK_TM1991v2,
- IK_HH1952v2,
- IKA1_HM1992v2,
- IKA2_HM1992v2,
- IKK2A_HM1992v2,
- IKK2B_HM1992v2,
- IKNI_Ya1989v2,
- IK_Leak,
+ PotassiumChannel as PotassiumChannel,
+ IKDR_Ba2002v2 as IKDR_Ba2002v2,
+ IK_TM1991v2 as IK_TM1991v2,
+ IK_HH1952v2 as IK_HH1952v2,
+ IKA1_HM1992v2 as IKA1_HM1992v2,
+ IKA2_HM1992v2 as IKA2_HM1992v2,
+ IKK2A_HM1992v2 as IKK2A_HM1992v2,
+ IKK2B_HM1992v2 as IKK2B_HM1992v2,
+ IKNI_Ya1989v2 as IKNI_Ya1989v2,
+ IK_Leak as IK_Leak,
)
from brainpy._src.dyn.channels.potassium_compatible import (
IKDR_Ba2002,
From 4cba72dc71edda3ecc2fd747c781c1fdffa5ca01 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 11:11:21 +0800
Subject: [PATCH 059/326] support `brainpy.math.for_loop` with the keyword
`unroll_kwargs`
---
.../_src/math/object_transform/controls.py | 13 +-
brainpy/_src/runners.py | 9 +-
brainpy/_src/running/runner.py | 1 -
brainpy/_src/tools/dicts.py | 46 +------
brainpy/_src/train/back_propagation.py | 2 +-
brainpy/_src/train/online.py | 2 +-
examples/dynamics_simulation/COBA.py | 1 -
tests/simulation/test_net_COBA.py | 118 ------------------
tests/training/test_ESN.py | 30 ++---
9 files changed, 31 insertions(+), 191 deletions(-)
delete mode 100644 tests/simulation/test_net_COBA.py
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index b3ef525a6..19efbf1af 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -723,11 +723,12 @@ def _get_for_loop_transform(
remat: bool,
reverse: bool,
unroll: int,
+ unroll_kwargs: tools.DotDict
):
def fun2scan(carry, x):
for k in dyn_vars.keys():
dyn_vars[k]._value = carry[k]
- results = body_fun(*x)
+ results = body_fun(*x, **unroll_kwargs)
if progress_bar:
id_tap(lambda *arg: bar.update(), ())
return dyn_vars.dict_data(), results
@@ -860,6 +861,7 @@ def for_loop(
if unroll_kwargs is None:
unroll_kwargs = dict()
+ unroll_kwargs = tools.DotDict(unroll_kwargs)
if not isinstance(operands, (list, tuple)):
operands = (operands,)
@@ -871,19 +873,20 @@ def for_loop(
if jit is None: # jax disable jit
jit = not jax.config.jax_disable_jit
- dyn_vars = get_stack_cache(body_fun)
+ dyn_vars = get_stack_cache((body_fun, unroll_kwargs))
if jit:
if dyn_vars is None:
# TODO: better cache mechanism?
with new_transform('for_loop'):
with VariableStack() as dyn_vars:
transform = _get_for_loop_transform(body_fun, VariableStack(), bar,
- progress_bar, remat, reverse, unroll)
+ progress_bar, remat, reverse, unroll,
+ unroll_kwargs)
if current_transform_number() > 1:
rets = transform(operands)
else:
rets = jax.eval_shape(transform, operands)
- cache_stack(body_fun, dyn_vars) # cache
+ cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache
if current_transform_number():
return rets[1]
del rets
@@ -893,7 +896,7 @@ def for_loop(
# TODO: cache mechanism?
transform = _get_for_loop_transform(body_fun, dyn_vars, bar,
progress_bar, remat, reverse,
- unroll)
+ unroll, unroll_kwargs)
if jit:
dyn_vals, out_vals = transform(operands)
else:
diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py
index 42b40b88e..73cf7f43d 100644
--- a/brainpy/_src/runners.py
+++ b/brainpy/_src/runners.py
@@ -466,7 +466,7 @@ def predict(
inputs = tree_map(lambda x: jnp.moveaxis(x, 0, 1), inputs)
# build monitor
- for key in self.mon.var_names:
+ for key in self._monitors.keys():
self.mon[key] = [] # reshape the monitor items
# init progress bar
@@ -492,7 +492,7 @@ def predict(
# post-running for monitors
if self._memory_efficient:
self.mon['ts'] = indices * self.dt + self.t0
- for key in self.mon.var_names:
+ for key in self._monitors.keys():
self.mon[key] = np.asarray(self.mon[key])
else:
hists['ts'] = indices * self.dt + self.t0
@@ -658,6 +658,7 @@ def _fun_predict(self, indices, *inputs, shared_args=None):
return outs, None
else:
- return bm.for_loop(functools.partial(self._step_func_predict, shared_args=shared_args),
+ return bm.for_loop(self._step_func_predict,
(indices, *inputs),
- jit=self.jit['predict'])
+ jit=self.jit['predict'],
+ unroll_kwargs={'shared_args': shared_args})
diff --git a/brainpy/_src/running/runner.py b/brainpy/_src/running/runner.py
index 2a2de3d3f..1b07e4e5a 100644
--- a/brainpy/_src/running/runner.py
+++ b/brainpy/_src/running/runner.py
@@ -118,7 +118,6 @@ def __init__(
# monitor for user access
self.mon = DotDict()
- self.mon['var_names'] = tuple(self._monitors.keys())
# progress bar
assert isinstance(progress_bar, bool), 'Must be a boolean variable.'
diff --git a/brainpy/_src/tools/dicts.py b/brainpy/_src/tools/dicts.py
index 97b869372..e8e207ae4 100644
--- a/brainpy/_src/tools/dicts.py
+++ b/brainpy/_src/tools/dicts.py
@@ -42,64 +42,20 @@ class DotDict(dict):
>>> f(d)
TypeError: Argument 'a' of type is not a valid JAX type.
- At this moment, you can label this attribute `names` as not a key in the dictionary
- by using the syntax::
-
- >>> d.add_attr_not_key('names')
- >>> f(d)
- {'a': DeviceArray(10, dtype=int32, weak_type=True),
- 'b': DeviceArray(20, dtype=int32, weak_type=True),
- 'c': DeviceArray(30, dtype=int32, weak_type=True)}
-
"""
- '''Used to exclude variables that '''
- attrs_not_keys = ('attrs_not_keys', 'var_names')
-
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
- self.var_names = ()
def copy(self) -> 'DotDict':
return type(self)(super().copy())
- def keys(self):
- """Retrieve all keys in the dict, excluding ignored keys."""
- keys = []
- for k in super(DotDict, self).keys():
- if k not in self.attrs_not_keys:
- keys.append(k)
- return tuple(keys)
-
- def values(self):
- """Retrieve all values in the dict, excluding values of ignored keys."""
- values = []
- for k, v in super(DotDict, self).items():
- if k not in self.attrs_not_keys:
- values.append(v)
- return tuple(values)
-
- def items(self):
- """Retrieve all items in the dict, excluding ignored items."""
- items = []
- for k, v in super(DotDict, self).items():
- if k not in self.attrs_not_keys:
- items.append((k, v))
- return items
-
def to_numpy(self):
"""Change all values to numpy arrays."""
for key in tuple(self.keys()):
self[key] = np.asarray(self[key])
- def add_attr_not_key(self, *args):
- """Add excluded attribute when retrieving dictionary keys. """
- for arg in args:
- if not isinstance(arg, str):
- raise TypeError('Only support string.')
- self.attrs_not_keys += args
-
def update(self, *args, **kwargs):
super().update(*args, **kwargs)
return self
@@ -179,7 +135,7 @@ def subset(self, var_type):
>>> import brainpy as bp
>>>
- >>> some_collector = Collector()
+ >>> some_collector = DotDict()
>>>
>>> # get all trainable variables
>>> some_collector.subset(bp.math.TrainVar)
diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py
index 806b68693..6f65783fe 100644
--- a/brainpy/_src/train/back_propagation.py
+++ b/brainpy/_src/train/back_propagation.py
@@ -605,7 +605,7 @@ def predict(
self.target.reset_state(self._get_input_batch_size(xs=inputs))
self.reset_state()
# init monitor
- for key in self.mon.var_names:
+ for key in self._monitors.keys():
self.mon[key] = [] # reshape the monitor items
# prediction
if not isinstance(inputs, (tuple, list)):
diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py
index 08214e7d7..e028f9c62 100644
--- a/brainpy/_src/train/online.py
+++ b/brainpy/_src/train/online.py
@@ -177,7 +177,7 @@ def fit(
is_leaf=lambda y: isinstance(y, bm.Array))
# init monitor
- for key in self.mon.var_names:
+ for key in self._monitors.keys():
self.mon[key] = [] # reshape the monitor items
# init progress bar
diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py
index 043ede354..5c49cfc9b 100644
--- a/examples/dynamics_simulation/COBA.py
+++ b/examples/dynamics_simulation/COBA.py
@@ -168,7 +168,6 @@ def run3():
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
-
def run4():
net = EICOBA_PostAlign(3200, 800)
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
diff --git a/tests/simulation/test_net_COBA.py b/tests/simulation/test_net_COBA.py
deleted file mode 100644
index 941f233a0..000000000
--- a/tests/simulation/test_net_COBA.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import brainpy as bp
-
-import unittest
-
-show = False
-
-class EINet(bp.DynamicalSystem):
- def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
- super().__init__()
-
- self.bg_exc = e_input
- self.bg_inh = i_input
-
- # network size
- num_exc = int(3200 * scale)
- num_inh = int(800 * scale)
-
- # neurons
- pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.), input_var=False)
- self.E = bp.neurons.LIF(num_exc, **pars)
- self.I = bp.neurons.LIF(num_inh, **pars)
-
- # synapses
- we = 0.6 / scale # excitatory synaptic weight (voltage)
- wi = 6.7 / scale # inhibitory synaptic weight
- self.E2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.)
- )
- self.E2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.)
- )
- self.I2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.)
- )
- self.I2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.)
- )
- self.delayE = bp.Delay(self.E.spike, entries={'E': delay})
- self.delayI = bp.Delay(self.I.spike, entries={'I': delay})
-
- def update(self):
- e_spike = self.delayE.at('E')
- i_spike = self.delayI.at('I')
- e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc
- i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh
- self.delayE(self.E(e_inp))
- self.delayI(self.I(i_inp))
-
-
-class EINetv2(bp.DynamicalSystem):
- def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None):
- super().__init__()
-
- self.bg_exc = e_input
- self.bg_inh = i_input
-
- # network size
- num_exc = int(3200 * scale)
- num_inh = int(800 * scale)
-
- # neurons
- pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
- V_initializer=bp.init.Normal(-55., 2.), input_var=False)
- self.E = bp.neurons.LIF(num_exc, **pars)
- self.I = bp.neurons.LIF(num_inh, **pars)
-
- # synapses
- we = 0.6 / scale # excitatory synaptic weight (voltage)
- wi = 6.7 / scale # inhibitory synaptic weight
- self.E2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.)
- )
- self.E2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ),
- g_max=we, tau=5., out=bp.experimental.COBA(E=0.)
- )
- self.I2E = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.)
- )
- self.I2I = bp.experimental.Exponential(
- bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size),
- g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.)
- )
- bp.share.save('E-spike', bp.Delay(self.E.spike, entries={'E': delay}))
- bp.share.save('I-spike', bp.Delay(self.I.spike, entries={'I': delay}))
-
- def update(self):
- e_spike = bp.share.load('E-spike').at('E')
- i_spike = bp.share.load('I-spike').at('I')
- e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc
- i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh
- self.E(e_inp)
- self.I(i_inp)
-
-
-class TestCOBA(unittest.TestSuite):
- def test1(self):
- net = EINet(delay=0., scale=2. if show else 0.1)
- runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
- r = runner.run(1., eval_time=True)
- if show:
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
- bp.math.clear_buffer_memory()
-
- def test2(self):
- net = EINetv2(delay=0., scale=2. if show else 0.1)
- runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
- r = runner.run(1., eval_time=True)
- if show:
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
- bp.math.clear_buffer_memory()
diff --git a/tests/training/test_ESN.py b/tests/training/test_ESN.py
index df36aa5f3..d543bc25e 100644
--- a/tests/training/test_ESN.py
+++ b/tests/training/test_ESN.py
@@ -6,17 +6,17 @@
class ESN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden, num_out):
super(ESN, self).__init__()
- self.r = bp.layers.Reservoir(num_in,
- num_hidden,
- Win_initializer=bp.init.Uniform(-0.1, 0.1),
- Wrec_initializer=bp.init.Normal(scale=0.1),
- in_connectivity=0.02,
- rec_connectivity=0.02,
- comp_type='dense')
- self.o = bp.layers.Dense(num_hidden,
- num_out,
- W_initializer=bp.init.Normal(),
- mode=bm.training_mode)
+ self.r = bp.dnn.Reservoir(num_in,
+ num_hidden,
+ Win_initializer=bp.init.Uniform(-0.1, 0.1),
+ Wrec_initializer=bp.init.Normal(scale=0.1),
+ in_connectivity=0.02,
+ rec_connectivity=0.02,
+ comp_type='dense')
+ self.o = bp.dnn.Dense(num_hidden,
+ num_out,
+ W_initializer=bp.init.Normal(),
+ mode=bm.training_mode)
def update(self, x):
return x >> self.r >> self.o
@@ -26,10 +26,10 @@ class NGRC(bp.DynamicalSystem):
def __init__(self, num_in, num_out):
super(NGRC, self).__init__()
- self.r = bp.layers.NVAR(num_in, delay=2, order=2)
- self.o = bp.layers.Dense(self.r.num_out, num_out,
- W_initializer=bp.init.Normal(0.1),
- mode=bm.training_mode)
+ self.r = bp.dnn.NVAR(num_in, delay=2, order=2)
+ self.o = bp.dnn.Dense(self.r.num_out, num_out,
+ W_initializer=bp.init.Normal(0.1),
+ mode=bm.training_mode)
def update(self, x):
return x >> self.r >> self.o
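Note on `unroll_kwargs`: entries in this dict are not scanned over; they are forwarded unchanged to every call of the loop body, and they become part of the cache key (which is why `DotDict` gained `__hash__`). A minimal sketch, assuming a build that includes this patch:

    import brainpy.math as bm

    def body(i, x, scale=1.):
        # `scale` is supplied once via unroll_kwargs, not iterated over
        return x * scale + i

    idx = bm.arange(5)
    xs = bm.arange(5.)
    out = bm.for_loop(body, (idx, xs), unroll_kwargs={'scale': 2.})
    # out[k] == xs[k] * 2. + idx[k] for each step k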
From 0c5b64f6964a7c61a2eb33c10fdfb9636b9eee9f Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 11:12:44 +0800
Subject: [PATCH 060/326] fix tests
---
brainpy/__init__.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 7bba216f5..1e98ab4a2 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -131,7 +131,6 @@
__deprecations = {
'Module': ('brainpy.Module', 'brainpy.DynamicalSystem', DynamicalSystem),
'Channel': ('brainpy.Channel', 'brainpy.dyn.IonChannel', dyn.IonChannel),
- 'NeuGroup': ('brainpy.NeuGroup', 'brainpy.dyn.NeuDyn', dyn.NeuDyn),
'SynConn': ('brainpy.SynConn', 'brainpy.dyn.SynConn', dyn.SynConn),
'Container': ('brainpy.Container', 'brainpy.DynSysGroup', DynSysGroup),
From 60eebd358e5c457cebd896e0ce78b9bd9b418deb Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 11:28:19 +0800
Subject: [PATCH 061/326] update CI
---
.github/workflows/CI-models.yml | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml
index 1b416ccc4..f5681cd75 100644
--- a/.github/workflows/CI-models.yml
+++ b/.github/workflows/CI-models.yml
@@ -117,7 +117,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.8", "3.9", "3.10", "3.11"]
+ python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
@@ -128,8 +128,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install numpy>=1.21.0
- python -m pip install "jaxlib==0.4.10" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
- python -m pip install jax==0.4.10
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
From 90a51a5f124483454437edd7d3819fecbe68aa4b Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 16:46:16 +0800
Subject: [PATCH 062/326] minor updates
---
brainpy/_src/dyn/ions/base.py | 4 ++--
brainpy/_src/dyn/neurons/lif.py | 4 ++--
brainpy/_src/dyn/projections/aligns.py | 2 +-
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py
index 175b9413e..74dd803ff 100644
--- a/brainpy/_src/dyn/ions/base.py
+++ b/brainpy/_src/dyn/ions/base.py
@@ -171,8 +171,8 @@ def current(self, V, C=None, E=None, external: bool = False):
Args:
V: The membrane potential.
- C: The ion concentration.
- E: The reversal potential.
+ C: The given ion concentration.
+ E: The given reversal potential.
external: Include the external current.
Returns:
diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py
index f8ba045fd..6c78280ac 100644
--- a/brainpy/_src/dyn/neurons/lif.py
+++ b/brainpy/_src/dyn/neurons/lif.py
@@ -115,7 +115,7 @@ def __init__(
def derivative(self, V, t, I):
for out in self.cur_inputs.values():
- I += out(V)
+ I = I + out(V)
return (-V + self.V_rest + self.R * I) / self.tau
def reset_state(self, batch_size=None):
@@ -141,7 +141,7 @@ def derivative(self, V, t, I):
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
- x += out(self.V.value)
+ x = x + out(self.V.value)
super().update(x)
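The change from ``I += out(V)`` to ``I = I + out(V)`` above swaps in-place accumulation for an out-of-place update. A minimal NumPy sketch (illustrative only, not BrainPy code) of why the functional form is the safer default when the accumulator may alias a caller-owned buffer:

import numpy as np

def accumulate_inplace(I, currents):
    for c in currents:
        I += c             # mutates the very array object the caller passed in
    return I

def accumulate_functional(I, currents):
    for c in currents:
        I = I + c          # builds a new array; the caller's buffer is untouched
    return I

x = np.zeros(3)
accumulate_inplace(x, [np.ones(3)])
print(x)                   # [1. 1. 1.] -- the input was modified as a side effect

y = np.zeros(3)
accumulate_functional(y, [np.ones(3)])
print(y)                   # [0. 0. 0.] -- the input is unchanged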
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index 15b92b0d4..4e907f086 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -4,7 +4,7 @@
from brainpy import math as bm, check
from brainpy._src.delay import Delay, VarDelay, DataDelay, DelayAccess
-from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic, Sequential
+from brainpy._src.dynsys import DynamicalSystem, Projection, Dynamic
from brainpy._src.mixin import JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost
__all__ = [
From 934c676e98138a52e7af447a0959c4be6f5081d1 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 16:46:22 +0800
Subject: [PATCH 063/326] update examples
---
examples/dynamics_simulation/COBA.py | 21 +++--
examples/dynamics_simulation/COBA_parallel.py | 77 +++++++++++++++++++
examples/dynamics_simulation/hh_model.py | 9 +++
3 files changed, 99 insertions(+), 8 deletions(-)
create mode 100644 examples/dynamics_simulation/COBA_parallel.py
diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py
index 5c49cfc9b..40e01b86f 100644
--- a/examples/dynamics_simulation/COBA.py
+++ b/examples/dynamics_simulation/COBA.py
@@ -56,12 +56,16 @@ def update(self):
class EICOBA_PostAlign(bp.DynamicalSystem):
- def __init__(self, num_exc, num_inh, inp=20.):
+ def __init__(self, num_exc, num_inh, inp=20., ltc=True):
super().__init__()
self.inp = inp
- self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars)
- self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars)
+ if ltc:
+ self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars)
+ self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars)
+ else:
+ self.E = bp.dyn.LifRef(num_exc, **neu_pars)
+ self.I = bp.dyn.LifRef(num_inh, **neu_pars)
self.E2E = bp.dyn.ProjAlignPostMg2(
pre=self.E,
@@ -145,10 +149,10 @@ def run1():
with bm.environment(mode=bm.BatchingMode(10)):
net = EICOBA_PostAlign(3200, 800)
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
+ bp.visualize.raster_plot(runner.mon['ts'], runner.mon['E.spike'][0], show=True)
print(runner.run(100., eval_time=True))
print(runner.mon['E.spike'].shape)
print(runner.mon['ts'].shape)
- bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'][0], show=True)
def run2():
@@ -169,14 +173,15 @@ def run3():
def run4():
- net = EICOBA_PostAlign(3200, 800)
+ bm.set(dt=0.5)
+ net = EICOBA_PostAlign(3200, 800, ltc=True)
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike})
print(runner.run(100., eval_time=True))
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True)
if __name__ == '__main__':
- run1()
- run2()
- run3()
+ # run1()
+ # run2()
+ # run3()
run4()
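Both neuron variants selected by the new ``ltc`` flag can be built from the same network class defined above; a tiny usage sketch (class name and signature taken from this file, the calls themselves are illustrative):

net_ltc = EICOBA_PostAlign(3200, 800, ltc=True)    # builds bp.dyn.LifRefLTC populations
net_std = EICOBA_PostAlign(3200, 800, ltc=False)   # builds bp.dyn.LifRef populations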
diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py
new file mode 100644
index 000000000..e7b0d15c4
--- /dev/null
+++ b/examples/dynamics_simulation/COBA_parallel.py
@@ -0,0 +1,77 @@
+import jax
+
+import brainpy as bp
+import brainpy.math as bm
+
+bm.set_host_device_count(4)
+
+
+class EINet1(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.),
+ sharding=[bm.sharding.NEU_AXIS])
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.ProjAlignPostMg1(
+ comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
+ syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.N
+ )
+ self.I = bp.dyn.ProjAlignPostMg1(
+ comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
+ syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.N
+ )
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+
+class EINet2(bp.DynSysGroup):
+ def __init__(self):
+ super().__init__()
+ self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+ V_initializer=bp.init.Normal(-55., 2.),
+ sharding=[bm.sharding.NEU_AXIS])
+ self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+ self.E = bp.dyn.ProjAlignPostMg1(
+ comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(0.02, pre=3200, post=4000), weight=0.6,
+ sharding=[None, bm.sharding.NEU_AXIS]),
+ syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.N
+ )
+ self.I = bp.dyn.ProjAlignPostMg1(
+ comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(0.02, pre=800, post=4000), weight=0.6,
+ sharding=[None, bm.sharding.NEU_AXIS]),
+ syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]),
+ out=bp.dyn.COBA.desc(E=-80.),
+ post=self.N
+ )
+
+ def update(self, input):
+ spk = self.delay.at('I')
+ self.E(spk[:3200])
+ self.I(spk[3200:])
+ self.delay(self.N(input))
+ return self.N.spike.value
+
+
+@bm.jit
+def run(indexes):
+ return bm.for_loop(lambda i: model.step_run(i, 20.), indexes)
+
+
+with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]):
+ # model = EINet1()
+ model = EINet2()
+ indices = bm.arange(1000)
+ spks = run(indices)
+bp.visualize.raster_plot(indices, spks, show=True)
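The example relies on ``bm.set_host_device_count(4)`` to expose four devices for the ``NEU_AXIS`` mesh. As a rough sketch of what that helper is assumed to do on a CPU-only machine (this mirrors the common XLA trick and is not taken from the BrainPy source), the same effect can usually be obtained by setting the flag before JAX initializes:

import os
# must run before `import jax` or any library that initializes JAX
os.environ['XLA_FLAGS'] = (os.environ.get('XLA_FLAGS', '') +
                           ' --xla_force_host_platform_device_count=4')
import jax
print(jax.devices())   # expect four CPU devices available for the device mesh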
diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py
index 6b64a6c10..0343ae89c 100644
--- a/examples/dynamics_simulation/hh_model.py
+++ b/examples/dynamics_simulation/hh_model.py
@@ -18,6 +18,15 @@ def __init__(self, size):
self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)
+class HHLTC(bp.dyn.CondNeuGroupLTC):
+ def __init__(self, size):
+ super().__init__(size)
+
+ self.INa = bp.channels.INa_HH1952(size)
+ self.IK = bp.channels.IK_HH1952(size)
+ self.IL = bp.channels.IL(size, E=-54.387, g_max=0.03)
+
+
class HHv2(bp.dyn.CondNeuGroupLTC):
def __init__(self, size):
super().__init__(size)
From 1880f8d91bb08ddd2ccf8efe67d1c77ecfbcc6d6 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 20:03:46 +0800
Subject: [PATCH 064/326] fix ``lif`` model bugs and support two kinds of spike
reset: ``soft`` and ``hard``
---
brainpy/_src/dyn/neurons/base.py | 2 +
brainpy/_src/dyn/neurons/lif.py | 185 +++++++++++++++++++++----------
2 files changed, 129 insertions(+), 58 deletions(-)
diff --git a/brainpy/_src/dyn/neurons/base.py b/brainpy/_src/dyn/neurons/base.py
index de4317a83..4ea3ba4d2 100644
--- a/brainpy/_src/dyn/neurons/base.py
+++ b/brainpy/_src/dyn/neurons/base.py
@@ -29,6 +29,7 @@ def __init__(
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
):
super().__init__(size=size,
@@ -38,6 +39,7 @@ def __init__(
sharding=sharding,
method=method)
+ self.spk_reset = spk_reset
self.spk_fun = is_callable(spk_fun)
self.detach_spk = detach_spk
self._spk_type = spk_type
diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py
index 6c78280ac..0690fce85 100644
--- a/brainpy/_src/dyn/neurons/lif.py
+++ b/brainpy/_src/dyn/neurons/lif.py
@@ -77,6 +77,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -96,7 +97,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset)
# parameters
self.V_rest = self.init_param(V_rest)
@@ -120,6 +122,7 @@ def derivative(self, V, t, I):
def reset_state(self, batch_size=None):
self.V = self.init_variable(self._V_initializer, batch_size)
+ self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
def update(self, x=None):
t = share.load('t')
@@ -128,6 +131,7 @@ def update(self, x=None):
# integrate membrane potential
self.V.value = self.integral(self.V.value, t, x, dt)
+
return self.V.value
def return_info(self):
@@ -142,7 +146,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x = x + out(self.V.value)
- super().update(x)
+ return super().update(x)
IF.__doc__ = IFLTC.__doc__ % ('', if_doc, pneu_doc, dpneu_doc)
@@ -183,6 +187,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -204,7 +209,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset)
# parameters
self.V_rest = self.init_param(V_rest)
@@ -244,7 +250,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike
+ else:
+ raise ValueError
else:
spike = V >= self.V_th
@@ -266,7 +277,7 @@ def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
x = x + out(self.V.value)
- super().update(x)
+ return super().update(x)
Lif.__doc__ = LifLTC.__doc__ % ('', lif_doc, pneu_doc, dpneu_doc)
@@ -310,6 +321,7 @@ def __init__(
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
detach_spk: bool = False,
+ spk_reset: str = 'soft',
method: str = 'exp_auto',
name: Optional[str] = None,
init_var: bool = True,
@@ -337,6 +349,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -387,7 +400,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike_no_grad
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike_no_grad
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike_no_grad
+ else:
+ raise ValueError
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
@@ -528,6 +546,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -551,7 +570,9 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset)
+
# parameters
self.V_rest = self.init_param(V_rest)
self.V_reset = self.init_param(V_reset)
@@ -594,7 +615,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike
+ else:
+ raise ValueError
else:
spike = V >= self.V_th
@@ -631,6 +657,7 @@ def __init__(
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
detach_spk: bool = False,
+ spk_reset: str = 'soft',
method: str = 'exp_auto',
name: Optional[str] = None,
init_var: bool = True,
@@ -660,6 +687,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -712,7 +740,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike_no_grad
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike_no_grad
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike_no_grad
+ else:
+ raise ValueError
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
@@ -834,6 +867,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -861,7 +895,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset)
# parameters
self.V_rest = self.init_param(V_rest)
self.V_reset = self.init_param(V_reset)
@@ -917,7 +952,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike
+ else:
+ raise ValueError
w += self.b * spike
else:
@@ -964,6 +1004,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
name: Optional[str] = None,
@@ -998,6 +1039,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -1055,7 +1097,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike_no_grad
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike_no_grad
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike_no_grad
+ else:
+ raise ValueError
w += self.b * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
@@ -1180,6 +1227,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -1203,7 +1251,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset)
# parameters
self.V_rest = self.init_param(V_rest)
self.V_reset = self.init_param(V_reset)
@@ -1245,7 +1294,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike
+ else:
+ raise ValueError
else:
spike = V >= self.V_th
@@ -1280,6 +1334,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
name: Optional[str] = None,
@@ -1310,6 +1365,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -1362,7 +1418,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike_no_grad
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike_no_grad
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike_no_grad
+ else:
+ raise ValueError
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
@@ -1485,6 +1546,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -1511,7 +1573,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset)
# parameters
self.V_rest = self.init_param(V_rest)
self.V_reset = self.init_param(V_reset)
@@ -1565,7 +1628,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike
+ else:
+ raise ValueError
w += self.b * spike
else:
@@ -1611,6 +1679,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
name: Optional[str] = None,
@@ -1644,6 +1713,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -1700,7 +1770,12 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike_no_grad
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike_no_grad
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike_no_grad
+ else:
+ raise ValueError
w += self.b * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
@@ -1839,6 +1914,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -1872,7 +1948,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset, )
# parameters
self.V_rest = self.init_param(V_rest)
self.V_reset = self.init_param(V_reset)
@@ -1939,11 +2016,15 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike
+ else:
+ raise ValueError
I1 += spike * (self.R1 * I1 + self.A1 - I1)
I2 += spike * (self.R2 * I2 + self.A2 - I2)
- reset_th = self.spk_fun(self.V_th_reset - V_th) * spike
- V_th += reset_th * (self.V_th_reset - V_th)
+ V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike
else:
spike = self.V_th <= V
@@ -1963,15 +2044,6 @@ def return_info(self):
class Gif(GifLTC):
- def dI1(self, I1, t):
- return - self.k1 * I1
-
- def dI2(self, I2, t):
- return - self.k2 * I2
-
- def dVth(self, V_th, t, V):
- return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf)
-
def dV(self, V, t, I1, I2, I):
return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau
@@ -1995,6 +2067,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
name: Optional[str] = None,
@@ -2035,6 +2108,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -2100,11 +2174,15 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += (self.V_reset - V) * spike
+ if self.spk_reset == 'soft':
+ V -= (self.V_th - self.V_reset) * spike_no_grad
+ elif self.spk_reset == 'hard':
+ V += (self.V_reset - V) * spike_no_grad
+ else:
+ raise ValueError
I1 += spike * (self.R1 * I1 + self.A1 - I1)
I2 += spike * (self.R2 * I2 + self.A2 - I2)
- reset_th = self.spk_fun(self.V_th_reset - V_th) * spike
- V_th += reset_th * (self.V_th_reset - V_th)
+ V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
if self.ref_var:
@@ -2130,22 +2208,9 @@ def update(self, x=None):
class GifRef(GifRefLTC):
- def dI1(self, I1, t):
- return - self.k1 * I1
-
- def dI2(self, I2, t):
- return - self.k2 * I2
-
- def dVth(self, V_th, t, V):
- return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf)
-
def dV(self, V, t, I1, I2, I):
return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau
- @property
- def derivative(self):
- return JointEq(self.dI1, self.dI2, self.dVth, self.dV)
-
def update(self, x=None):
x = 0. if x is None else x
for out in self.cur_inputs.values():
@@ -2153,10 +2218,10 @@ def update(self, x=None):
return super().update(x)
-Gif.__doc__ = GifLTC.__doc__ % ('')
-GifRefLTC.__doc__ = GifLTC.__doc__ % (ltc_doc)
-GifRef.__doc__ = GifLTC.__doc__ % ('')
-GifLTC.__doc__ = GifLTC.__doc__ % (ltc_doc)
+Gif.__doc__ = GifLTC.__doc__ % ''
+GifRefLTC.__doc__ = GifLTC.__doc__ % ltc_doc
+GifRef.__doc__ = GifLTC.__doc__ % ''
+GifLTC.__doc__ = GifLTC.__doc__ % ltc_doc
class IzhikevichLTC(GradNeuDyn):
@@ -2236,6 +2301,7 @@ def __init__(
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
init_var: bool = True,
@@ -2260,7 +2326,8 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
- spk_type=spk_type)
+ spk_type=spk_type,
+ spk_reset=spk_reset, )
# parameters
self.V_th = self.init_param(V_th)
self.a = self.init_param(a)
@@ -2314,7 +2381,7 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike = stop_gradient(spike) if self.detach_spk else spike
- V += spike * (self.c - self.V_th)
+ V += spike * (self.c - V)
u += spike * self.d
else:
@@ -2360,6 +2427,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
+ spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
name: Optional[str] = None,
@@ -2391,6 +2459,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
+ spk_reset=spk_reset,
init_var=False,
@@ -2445,7 +2514,7 @@ def update(self, x=None):
if isinstance(self.mode, bm.TrainingMode):
spike = self.spk_fun(V - self.V_th)
spike_no_grad = stop_gradient(spike) if self.detach_spk else spike
- V += spike * (self.c - self.V_th)
+ V += spike * (self.c - V)
u += spike * self.d
spike_ = spike_no_grad > 0.
# will be used in other place, like Delta Synapse, so stop its gradient
@@ -2487,7 +2556,7 @@ def update(self, x=None):
return super().update(x)
-Izhikevich.__doc__ = IzhikevichLTC.__doc__ % ('')
-IzhikevichRefLTC.__doc__ = IzhikevichLTC.__doc__ % (ltc_doc)
-IzhikevichRef.__doc__ = IzhikevichLTC.__doc__ % ('')
-IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ % (ltc_doc)
+Izhikevich.__doc__ = IzhikevichLTC.__doc__ % ''
+IzhikevichRefLTC.__doc__ = IzhikevichLTC.__doc__ % ltc_doc
+IzhikevichRef.__doc__ = IzhikevichLTC.__doc__ % ''
+IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ % ltc_doc
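Besides the new reset modes, this patch corrects the Izhikevich training-mode reset: ``V += spike * (self.c - V)`` drives the membrane potential exactly to ``c`` for a full spike, whereas the old ``spike * (self.c - self.V_th)`` only applied a fixed offset. A quick numeric check with illustrative values:

c, V_th = -65.0, 30.0
V, spike = 32.5, 1.0             # V has overshot the threshold by 2.5
print(V + spike * (c - V))       # -65.0: corrected rule lands exactly on c
print(V + spike * (c - V_th))    # -62.5: old rule leaves a residual that depends on the overshoot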
From b3e357f65934f4b8f709efba87cc0afe40f1211d Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 21 Jul 2023 20:04:57 +0800
Subject: [PATCH 065/326] update docs
---
brainpy/_src/dyn/_docs.py | 3 +
docs/quickstart/analysis.ipynb | 177 ++++++-----
docs/quickstart/training.ipynb | 554 ++++++++++++---------------------
3 files changed, 298 insertions(+), 436 deletions(-)
diff --git a/brainpy/_src/dyn/_docs.py b/brainpy/_src/dyn/_docs.py
index 823be6787..c2c75ffc9 100644
--- a/brainpy/_src/dyn/_docs.py
+++ b/brainpy/_src/dyn/_docs.py
@@ -11,6 +11,9 @@
detach_spk: bool.
method: str. The numerical integration method.
spk_type: The spike data type.
+ spk_reset: The way to reset the membrane potential when the neuron generates spikes.
+ This parameter only works when the computing mode is ``TrainingMode``.
+    It can be ``soft`` or ``hard``. Default is ``soft``.
'''.strip()
ref_doc = '''
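The ``spk_reset`` behaviour documented above can be sketched as a plain function (illustrative values only, not the library implementation): a soft reset subtracts the fixed distance ``V_th - V_reset`` and keeps any overshoot, while a hard reset clamps the potential back to ``V_reset``.

V_th, V_reset = -50.0, -60.0

def reset(V, spike, spk_reset='soft'):
    if spk_reset == 'soft':
        return V - (V_th - V_reset) * spike    # keep the overshoot above threshold
    elif spk_reset == 'hard':
        return V + (V_reset - V) * spike       # pull V all the way back to V_reset
    raise ValueError(f'unknown spk_reset: {spk_reset!r}')

print(reset(-48.0, 1.0, 'soft'))   # -58.0 (the 2 mV overshoot is retained)
print(reset(-48.0, 1.0, 'hard'))   # -60.0 (clamped to V_reset)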
diff --git a/docs/quickstart/analysis.ipynb b/docs/quickstart/analysis.ipynb
index 14b4a2fd6..02515a1aa 100644
--- a/docs/quickstart/analysis.ipynb
+++ b/docs/quickstart/analysis.ipynb
@@ -37,8 +37,8 @@
"id": "993ca509",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:02.878689Z",
- "end_time": "2023-04-15T13:35:03.749844Z"
+ "end_time": "2023-07-21T08:53:38.185849800Z",
+ "start_time": "2023-07-21T08:53:37.076294Z"
}
},
"outputs": [],
@@ -57,7 +57,7 @@
"outputs": [
{
"data": {
- "text/plain": "'2.4.0'"
+ "text/plain": "'2.4.3'"
},
"execution_count": 2,
"metadata": {},
@@ -70,8 +70,8 @@
"metadata": {
"collapsed": false,
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:03.749844Z",
- "end_time": "2023-04-15T13:35:03.764381Z"
+ "end_time": "2023-07-21T08:53:38.204162500Z",
+ "start_time": "2023-07-21T08:53:38.185849800Z"
}
}
},
@@ -119,13 +119,13 @@
"id": "8d6b11cb",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:03.764381Z",
- "end_time": "2023-04-15T13:35:03.811297Z"
+ "end_time": "2023-07-21T08:53:38.240397100Z",
+ "start_time": "2023-07-21T08:53:38.205190900Z"
}
},
"outputs": [],
"source": [
- "expif = bp.neurons.ExpIF(1, delta_T=1.)"
+ "expif = bp.dyn.ExpIF(1, delta_T=1.)"
]
},
{
@@ -142,8 +142,8 @@
"id": "040b7004",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:03.811297Z",
- "end_time": "2023-04-15T13:35:03.826935Z"
+ "end_time": "2023-07-21T08:53:38.271666300Z",
+ "start_time": "2023-07-21T08:53:38.240397100Z"
}
},
"outputs": [
@@ -165,7 +165,7 @@
"id": "09f5722a",
"metadata": {},
"source": [
- "After defining the model, we can use it for bifurcation analysis."
+ "After defining the model, we can use it for bifurcation analysis. Note that, the following analysis"
]
},
{
@@ -174,8 +174,8 @@
"id": "358060fb",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:03.826935Z",
- "end_time": "2023-04-15T13:35:06.166395Z"
+ "end_time": "2023-07-21T08:53:39.762842400Z",
+ "start_time": "2023-07-21T08:53:38.271666300Z"
}
},
"outputs": [
@@ -189,7 +189,7 @@
{
"data": {
"text/plain": "