diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6f183e0b..9b1f6993 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,5 +1,7 @@
 # Configuration file for the Sphinx documentation builder.
 
+from sphinx_gallery.sorting import FileNameSortKey
+
 # -- Project information
 
 project = 'PyJuice'
@@ -25,6 +27,8 @@
     'examples_dirs': '../../examples',
     # other configuration options
     'gallery_dirs': 'getting-started/tutorials',
+    # sort key
+    'within_subsection_order': FileNameSortKey
 }
 
 intersphinx_mapping = {
diff --git a/examples/01_train_pc.py b/examples/01_train_pc.py
new file mode 100644
index 00000000..2bb113b9
--- /dev/null
+++ b/examples/01_train_pc.py
@@ -0,0 +1,163 @@
+"""
+Train a PC
+==========
+
+This tutorial demonstrates how to create a Hidden Chow-Liu Tree (HCLT; https://arxiv.org/pdf/2106.02264.pdf) using `pyjuice.structures` and how to train the model with mini-batch EM and full-batch EM.
+
+For simplicity, we use the MNIST dataset as an example.
+"""
+
+# sphinx_gallery_thumbnail_path = 'imgs/juice.png'
+
+# %%
+# Load the MNIST Dataset
+# ----------------------
+
+import pyjuice as juice
+import torch
+import torchvision
+import time
+from torch.utils.data import TensorDataset, DataLoader
+import pyjuice.nodes.distributions as dists
+
+train_dataset = torchvision.datasets.MNIST(root = "../data", train = True, download = True)
+valid_dataset = torchvision.datasets.MNIST(root = "../data", train = False, download = True)
+
+train_data = train_dataset.data.reshape(60000, 28*28)
+valid_data = valid_dataset.data.reshape(10000, 28*28)
+
+train_loader = DataLoader(
+    dataset = TensorDataset(train_data),
+    batch_size = 512,
+    shuffle = True,
+    drop_last = True
+)
+valid_loader = DataLoader(
+    dataset = TensorDataset(valid_data),
+    batch_size = 512,
+    shuffle = False,
+    drop_last = True
+)
+
+# %%
+# Create the PC
+# -------------
+
+# %%
+# Let's create an HCLT PC with latent size 128.
+
+device = torch.device("cuda:0")
+
+# The data is required to construct the backbone Chow-Liu Tree structure of the HCLT
+ns = juice.structures.HCLT(
+    train_data.float().to(device),
+    num_latents = 128
+)
+
+# %%
+# We proceed to compile the PC with `pyjuice.compile`.
+
+pc = juice.compile(ns)
+
+# %%
+# `pc` is an instance of `torch.nn.Module`, so we can move it to the GPU just like a neural network.
+
+pc.to(device)
+
+# %%
+# Train the PC
+# ------------
+
+# %%
+# We start by defining the optimizer and scheduler.
+
+optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1, method = "EM")
+scheduler = juice.optim.CircuitScheduler(
+    optimizer,
+    method = "multi_linear",
+    lrs = [0.9, 0.1, 0.05],
+    milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350]
+)
+
+# %%
+# Optionally, we can leverage CUDA Graphs to hide the kernel launch overhead by doing a dry run.
+
+for batch in train_loader:
+    x = batch[0].to(device)
+
+    lls = pc(x, record_cudagraph = True)
+    lls.mean().backward()
+    break
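+
+# %%
+# As a quick sanity check (an optional sketch that simply mirrors the validation loop
+# used in the training code below), we can measure the average validation
+# log-likelihood of the untrained PC.
+
+init_ll = 0.0
+for batch in valid_loader:
+    x = batch[0].to(device)
+    init_ll += pc(x).mean().detach().cpu().numpy().item()
+init_ll /= len(valid_loader)
+print(f"[init val LL: {init_ll:.2f}]")
+
+# %%
+# We are now ready to train. Below is an example training loop for mini-batch EM.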
+
+for epoch in range(1, 350+1):
+    t0 = time.time()
+    train_ll = 0.0
+    for batch in train_loader:
+        x = batch[0].to(device)
+
+        # Just as PyTorch optimizers zero out gradients, we zero out the parameter flows
+        optimizer.zero_grad()
+
+        # Forward pass
+        lls = pc(x)
+
+        # Backward pass
+        lls.mean().backward()
+
+        train_ll += lls.mean().detach().cpu().numpy().item()
+
+        # Perform a mini-batch EM step
+        optimizer.step()
+        scheduler.step()
+
+    train_ll /= len(train_loader)
+
+    t1 = time.time()
+    test_ll = 0.0
+    for batch in valid_loader:
+        x = batch[0].to(pc.device)
+        lls = pc(x)
+        test_ll += lls.mean().detach().cpu().numpy().item()
+
+    test_ll /= len(valid_loader)
+    t2 = time.time()
+
+    print(f"[Epoch {epoch}/{350}][train LL: {train_ll:.2f}; val LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; val forward {t2-t1:.2f}] ")
+
+# %%
+# Similarly, an example training loop for full-batch EM is given below.
+
+for epoch in range(1, 1+1):
+    t0 = time.time()
+
+    # Manually zero out the flows
+    pc.init_param_flows(flows_memory = 0.0)
+
+    train_ll = 0.0
+    for batch in train_loader:
+        x = batch[0].to(device)
+
+        # We only run the forward and backward passes, accumulating the flows throughout the epoch
+        lls = pc(x)
+        lls.mean().backward()
+
+        train_ll += lls.mean().detach().cpu().numpy().item()
+
+    # Set the step size to 1.0 for full-batch EM
+    pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01)
+
+    train_ll /= len(train_loader)
+
+    t1 = time.time()
+    test_ll = 0.0
+    for batch in valid_loader:
+        x = batch[0].to(pc.device)
+        lls = pc(x)
+        test_ll += lls.mean().detach().cpu().numpy().item()
+
+    test_ll /= len(valid_loader)
+    t2 = time.time()
+    print(f"[Epoch {epoch}/{1}][train LL: {train_ll:.2f}; val LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; val forward {t2-t1:.2f}] ")
diff --git a/examples/02_construct_hmm.py b/examples/02_construct_hmm.py
new file mode 100644
index 00000000..92c5798a
--- /dev/null
+++ b/examples/02_construct_hmm.py
@@ -0,0 +1,60 @@
+"""
+Construct an HMM
+================
+
+This tutorial demonstrates how to construct an HMM with the PyJuice primitives `inputs`, `multiply`, and `summate`.
+"""
+
+# sphinx_gallery_thumbnail_path = 'imgs/juice.png'
+
+import torch
+import pyjuice as juice
+import pyjuice.nodes.distributions as dists
+
+# %%
+# We start by specifying the structural parameters of the HMM.
+
+seq_length = 32
+num_latents = 2048
+num_emits = 4023
+
+# %%
+# An important parameter to determine is the block size, which is crucial for PyJuice to compile efficient models.
+# Specifically, we want the block size to be large enough so that PyJuice can leverage block-based parallelization.
+
+block_size = min(juice.utils.util.max_cdf_power_of_2(num_latents), 1024)
+
+# %%
+# The number of node blocks is derived accordingly.
+
+num_node_blocks = num_latents // block_size
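+
+# %%
+# For example, with `num_latents = 2048` (assuming `max_cdf_power_of_2` returns the
+# largest power of two not exceeding its input, here 2048), the `min` caps the result
+# at 1024, so `block_size = 1024` and `num_node_blocks = 2048 // 1024 = 2`. The
+# assertion below merely sanity-checks this arithmetic for the settings above.
+
+assert block_size == 1024 and num_node_blocks == 2
+
+# %%
+# We use the context manager `set_block_size` to set the block size of all PC nodes.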
+# In the following, we assume `T = seq_length` and `K = num_latents`.
+
+with juice.set_block_size(block_size):
+    # We begin by defining p(X_{T-1}|Z_{T-1}=k) for all k = 0...K-1
+    ns_input = juice.inputs(seq_length - 1, num_node_blocks = num_node_blocks,
+                            dist = dists.Categorical(num_cats = num_emits))
+
+    ns_sum = None
+    curr_zs = ns_input
+    for var in range(seq_length - 2, -1, -1):
+        # The emission probabilities p(X_{var}|Z_{var}=k) for all k = 0...K-1
+        curr_xs = ns_input.duplicate(var, tie_params = True)
+
+        # The transition probabilities p(Z_{var+1}|Z_{var})
+        if ns_sum is None:
+            # Create both the structure and the transition probabilities
+            ns = juice.summate(curr_zs, num_node_blocks = num_node_blocks)
+            ns_sum = ns
+        else:
+            # Create only the structure and reuse the transition probabilities from `ns_sum`
+            ns = ns_sum.duplicate(curr_zs, tie_params = True)
+
+        curr_zs = juice.multiply(curr_xs, ns)
+
+    # The initial probabilities p(Z_{0})
+    ns = juice.summate(curr_zs, num_node_blocks = 1, block_size = 1)
+
diff --git a/examples/train_hmm.py b/examples/train_hmm.py
deleted file mode 100644
index 83eee436..00000000
--- a/examples/train_hmm.py
+++ /dev/null
@@ -1,160 +0,0 @@
-"""
-Example
-=======
-
-dddd
-"""
-
-# sphinx_gallery_thumbnail_path = 'imgs/juice.png'
-
-import pyjuice as juice
-import torch
-import torchvision
-import time
-import tqdm
-from torch.utils.data import TensorDataset, DataLoader
-import pyjuice.nodes.distributions as dists
-
-
-def evaluate(pc, loader):
-    lls_total = 0.0
-    for batch in loader:
-        x = batch[0].to(pc.device)
-        lls = pc(x)
-        lls_total += lls.mean().detach().cpu().numpy().item()
-
-    lls_total /= len(loader)
-    return lls_total
-
-
-def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device):
-    for epoch in range(num_epochs):
-        t0 = time.time()
-        train_ll = 0.0
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            optimizer.zero_grad()
-
-            lls = pc(x)
-            lls.mean().backward()
-
-            train_ll += lls.mean().detach().cpu().numpy().item()
-
-            optimizer.step()
-            if scheduler is not None:
-                scheduler.step()
-
-        train_ll /= len(train_loader)
-
-        t1 = time.time()
-        test_ll = evaluate(pc, loader=test_loader)
-        t2 = time.time()
-
-        print(f"[Epoch {epoch}/{num_epochs}][train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def full_batch_em_epoch(pc, train_loader, test_loader, device):
-
-    pc.init_param_flows(flows_memory = 0.0)
-
-    t0 = time.time()
-    train_ll = 0.0
-    for batch in tqdm.tqdm(train_loader):
-        x = batch[0].to(device)
-
-        lls = pc(x)
-        lls.mean().backward()
-
-        train_ll += lls.mean().detach().cpu().numpy().item()
-
-    pc.mini_batch_em(step_size = 1.0, pseudocount = 0.1)
-
-    train_ll /= len(train_loader)
-
-    t1 = time.time()
-    test_ll = evaluate(pc, loader=test_loader)
-    t2 = time.time()
-    print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def homogenes_hmm(seq_length, num_latents, vocab_size):
-
-    group_size = min(juice.utils.util.max_cdf_power_of_2(num_latents), 1024)
-    num_node_groups = num_latents // group_size
-
-    with juice.set_group_size(group_size = group_size):
-        ns_input = juice.inputs(seq_length - 1, num_node_groups = num_node_groups,
-                                dist = dists.Categorical(num_cats = vocab_size))
-
-        ns_sum = None
-        curr_zs = ns_input
-        for var in range(seq_length - 2, -1, -1):
-            curr_xs = ns_input.duplicate(var, tie_params = True)
-
-            if ns_sum is None:
-                ns = juice.summate(
-                    curr_zs, num_node_groups = num_node_groups)
-                ns_sum = ns
-            else:
-                ns = ns_sum.duplicate(curr_zs, tie_params=True)
-
-            curr_zs = juice.multiply(curr_xs, ns)
-
-        ns = juice.summate(curr_zs, num_node_groups = 1, group_size = 1)
-
-    ns.init_parameters()
-
-    return ns
-
-
-def train_hmm(enable_cudagrph = True):
-
-    device = torch.device("cuda:0")
-
-    T = 32
-    ns = homogenes_hmm(T, 8192, 4023)
-
-    pc = juice.TensorCircuit(ns, max_tied_ns_per_parflow_group = 2)
-    pc.print_statistics()
-
-    pc.to(device)
-
-    data = torch.randint(0, 10000, (6400, T))
-
-    data_loader = DataLoader(
-        dataset = TensorDataset(data),
-        batch_size = 64,
-        shuffle = True,
-        drop_last = True
-    )
-
-    optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.0001)
-
-    for batch in tqdm.tqdm(data_loader):
-        x = batch[0].to(device)
-
-        lls = pc(x)
-        lls.mean().backward()
-
-        break
-
-    torch.cuda.synchronize()
-    t0 = time.time()
-
-    for batch in tqdm.tqdm(data_loader):
-        x = batch[0].to(device)
-
-        lls = pc(x)
-        lls.mean().backward()
-
-    torch.cuda.synchronize()
-    t1 = time.time()
-
-    print((t1-t0)/100*1000, "ms")
-
-    # mini_batch_em_epoch(350, pc, optimizer, None, data_loader, data_loader, device)
-
-
-if __name__ == "__main__":
-    train_hmm()
diff --git a/examples/train_mnist_hclt.py b/examples/train_mnist_hclt.py
deleted file mode 100644
index 42c0ce6e..00000000
--- a/examples/train_mnist_hclt.py
+++ /dev/null
@@ -1,134 +0,0 @@
-"""
-Example
-=======
-
-cccc
-"""
-
-# sphinx_gallery_thumbnail_path = 'imgs/juice.png'
-
-import pyjuice as juice
-import torch
-import torchvision
-import time
-from torch.utils.data import TensorDataset, DataLoader
-import pyjuice.nodes.distributions as dists
-
-
-def evaluate(pc, loader):
-    lls_total = 0.0
-    for batch in loader:
-        x = batch[0].to(pc.device)
-        lls = pc(x)
-        lls_total += lls.mean().detach().cpu().numpy().item()
-
-    lls_total /= len(loader)
-    return lls_total
-
-
-def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device):
-    for epoch in range(num_epochs):
-        t0 = time.time()
-        train_ll = 0.0
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            optimizer.zero_grad()
-
-            lls = pc(x)
-            lls.mean().backward()
-
-            train_ll += lls.mean().detach().cpu().numpy().item()
-
-            optimizer.step()
-            scheduler.step()
-
-        train_ll /= len(train_loader)
-
-        t1 = time.time()
-        test_ll = evaluate(pc, loader=test_loader)
-        t2 = time.time()
-
-        print(f"[Epoch {epoch}/{num_epochs}][train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def full_batch_em_epoch(pc, train_loader, test_loader, device):
-    t0 = time.time()
-    train_ll = 0.0
-    for batch in train_loader:
-        x = batch[0].to(device)
-
-        lls = pc(x)
-        lls.mean().backward()
-
-        train_ll += lls.mean().detach().cpu().numpy().item()
-
-    pc.mini_batch_em(step_size = 1.0, pseudocount = 0.1)
-
-    train_ll /= len(train_loader)
-
-    t1 = time.time()
-    test_ll = evaluate(pc, loader=test_loader)
-    t2 = time.time()
-    print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def train_mnist_hclt(enable_cudagrph = True):
-
-    device = torch.device("cuda:0")
-
-    train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)
-    test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True)
-
-    train_data = train_dataset.data.reshape(60000, 28*28)
-    test_data = test_dataset.data.reshape(10000, 28*28)
-
-    num_features = train_data.size(1)
-
-    train_loader = DataLoader(
-        dataset = TensorDataset(train_data),
-        batch_size = 512,
-        shuffle = True,
-        drop_last = True
-    )
-    test_loader = DataLoader(
-        dataset = TensorDataset(test_data),
-        batch_size = 512,
-        shuffle = False,
-        drop_last = True
-    )
-
-    ns = juice.structures.HCLT(
-        train_data.float().to(device),
-        num_bins = 32,
-        sigma = 0.5 / 32,
-        num_latents = 128,
-        chunk_size = 32
-    )
-    pc = juice.TensorCircuit(ns)
-
-    pc.to(device)
-
-    optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1)
-    scheduler = juice.optim.CircuitScheduler(
-        optimizer,
-        method = "multi_linear",
-        lrs = [0.9, 0.1, 0.05],
-        milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350]
-    )
-
-    if enable_cudagrph:
-        # Dry run to record CUDA graphs
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            lls = pc(x, record_cudagraph = True)
-            lls.mean().backward()
-            break
-
-    mini_batch_em_epoch(350, pc, optimizer, scheduler, train_loader, test_loader, device)
-    full_batch_em_epoch(pc, train_loader, test_loader, device)
-
-
-if __name__ == "__main__":
-    train_mnist_hclt()
diff --git a/examples/train_mnist_pd.py b/examples/train_mnist_pd.py
deleted file mode 100644
index 9b3e10af..00000000
--- a/examples/train_mnist_pd.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""
-Example
-=======
-
-bbbb
-"""
-
-# sphinx_gallery_thumbnail_path = 'imgs/juice.png'
-
-import pyjuice as juice
-import torch
-import torchvision
-import time
-from torch.utils.data import TensorDataset, DataLoader
-
-
-def evaluate(pc, loader):
-    lls_total = 0.0
-    for batch in loader:
-        x = batch[0].to(pc.device)
-        lls = pc(x)
-        lls_total += lls.mean().detach().cpu().numpy().item()
-
-    lls_total /= len(loader)
-    return lls_total
-
-
-def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device):
-    for epoch in range(num_epochs):
-        t0 = time.time()
-        train_ll = 0.0
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            optimizer.zero_grad()
-
-            lls = pc(x)
-            lls.mean().backward()
-
-            train_ll += lls.mean().detach().cpu().numpy().item()
-
-            optimizer.step()
-            if scheduler is not None:
-                scheduler.step()
-
-        train_ll /= len(train_loader)
-
-        t1 = time.time()
-        test_ll = evaluate(pc, loader=test_loader)
-        t2 = time.time()
-
-        print(f"[Epoch {epoch}/{num_epochs}][train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def full_batch_em_epoch(pc, train_loader, test_loader, device):
-    with torch.no_grad():
-        t0 = time.time()
-        train_ll = 0.0
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            lls = pc(x)
-            pc.backward(x, flows_memory = 1.0)
-
-            train_ll += lls.mean().detach().cpu().numpy().item()
-
-        pc.mini_batch_em(step_size = 1.0, pseudocount = 0.1)
-
-        train_ll /= len(train_loader)
-
-        t1 = time.time()
-        test_ll = evaluate(pc, loader=test_loader)
-        t2 = time.time()
-        print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def train_mnist_pd():
-
-    device = torch.device("cuda:0")
-
-    train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)
-    test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True)
-
-    train_data = train_dataset.data.reshape(60000, 28*28)
-    test_data = test_dataset.data.reshape(10000, 28*28)
-
-    num_features = train_data.size(1)
-
-    train_loader = DataLoader(
-        dataset = TensorDataset(train_data),
-        batch_size = 512,
-        shuffle = True,
-        drop_last = True
-    )
-    test_loader = DataLoader(
-        dataset = TensorDataset(test_data),
-        batch_size = 512,
-        shuffle = False,
-        drop_last = True
-    )
-
-    ns = juice.structures.PD(
-        data_shape = (28, 28),
-        num_latents = 512,
-        split_intervals = (4, 4),
-        structure_type = "sum_dominated"
-    )
-    pc = juice.TensorCircuit(ns)
-    pc.print_statistics()
-
-    pc.to(device)
-
-    optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.0001)
-
-    mini_batch_em_epoch(350, pc, optimizer, None, train_loader, test_loader, device)
-    full_batch_em_epoch(pc, train_loader, test_loader, device)
-
-
-if __name__ == "__main__":
-    train_mnist_pd()
diff --git a/examples/train_mnist_pdhclt.py b/examples/train_mnist_pdhclt.py
deleted file mode 100644
index 0187e20a..00000000
--- a/examples/train_mnist_pdhclt.py
+++ /dev/null
@@ -1,129 +0,0 @@
-"""
-Example
-=======
-
-aaaa
-"""
-
-# sphinx_gallery_thumbnail_path = 'imgs/juice.png'
-
-import pyjuice as juice
-import torch
-import torchvision
-import time
-from torch.utils.data import TensorDataset, DataLoader
-
-
-def evaluate(pc, loader):
-    lls_total = 0.0
-    for batch in loader:
-        x = batch[0].to(pc.device)
-        lls = pc(x)
-        lls_total += lls.mean().detach().cpu().numpy().item()
-
-    lls_total /= len(loader)
-    return lls_total
-
-
-def mini_batch_em_epoch(num_epochs, pc, optimizer, scheduler, train_loader, test_loader, device):
-    for epoch in range(num_epochs):
-        t0 = time.time()
-        train_ll = 0.0
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            optimizer.zero_grad()
-
-            lls = pc(x)
-            lls.mean().backward()
-
-            train_ll += lls.mean().detach().cpu().numpy().item()
-
-            optimizer.step()
-            if scheduler is not None:
-                scheduler.step()
-
-        train_ll /= len(train_loader)
-
-        t1 = time.time()
-        test_ll = evaluate(pc, loader=test_loader)
-        t2 = time.time()
-
-        print(f"[Epoch {epoch}/{num_epochs}][train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def full_batch_em_epoch(pc, train_loader, test_loader, device):
-    with torch.no_grad():
-        t0 = time.time()
-        train_ll = 0.0
-        for batch in train_loader:
-            x = batch[0].to(device)
-
-            lls = pc(x)
-            pc.backward(x, flows_memory = 1.0)
-
-            train_ll += lls.mean().detach().cpu().numpy().item()
-
-        pc.mini_batch_em(step_size = 1.0, pseudocount = 0.1)
-
-        train_ll /= len(train_loader)
-
-        t1 = time.time()
-        test_ll = evaluate(pc, loader=test_loader)
-        t2 = time.time()
-        print(f"[train LL: {train_ll:.2f}; test LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; test forward {t2-t1:.2f}] ")
-
-
-def pd_hclt_test():
-
-    device = torch.device("cuda:0")
-
-    train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)
-    test_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = False, download = True)
-
-    train_data = train_dataset.data.reshape(60000, 28*28)
-    test_data = test_dataset.data.reshape(10000, 28*28)
-
-    num_features = train_data.size(1)
-
-    train_loader = DataLoader(
-        dataset = TensorDataset(train_data),
-        batch_size = 512,
-        shuffle = True,
-        drop_last = True
-    )
-    test_loader = DataLoader(
-        dataset = TensorDataset(test_data),
-        batch_size = 512,
-        shuffle = False,
-        drop_last = True
-    )
-
-    ns = juice.structures.PDHCLT(
-        train_data.cuda(),
-        data_shape = (28, 28),
-        num_latents = 128,
-        split_intervals = (4, 4),
-        structure_type = "sum_dominated"
-    )
-    pc = juice.TensorCircuit(ns)
-
-    pc.to(device)
-
-    optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1)
-    scheduler = juice.optim.CircuitScheduler(
-        optimizer,
-        method = "multi_linear",
-        lrs = [0.9, 0.1, 0.05],
-        milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350]
-    )
-
-    pc.print_statistics()
-
-    mini_batch_em_epoch(350, pc, optimizer, None, train_loader, test_loader, device)
-    full_batch_em_epoch(pc, train_loader, test_loader, device)
-
-
-if __name__ == "__main__":
-    torch.manual_seed(2391)
-    pd_hclt_test()
diff --git a/src/pyjuice/structures/hmm.py b/src/pyjuice/structures/hmm.py
index d69acbc7..d897d936 100644
--- a/src/pyjuice/structures/hmm.py
+++ b/src/pyjuice/structures/hmm.py
@@ -47,13 +47,13 @@ def HMM(seq_length: int, num_latents: int, num_emits: int, homogeneous: bool = T
     :type homogeneous: bool
     """
 
-    group_size = min(max_cdf_power_of_2(num_latents), 1024)
-    num_node_groups = num_latents // group_size
+    block_size = min(max_cdf_power_of_2(num_latents), 1024)
+    num_node_blocks = num_latents // block_size
 
-    with juice.set_group_size(group_size = group_size):
+    with juice.set_block_size(block_size = block_size):
 
         ns_input = inputs(
-            seq_length - 1, num_node_groups = num_node_groups,
+            seq_length - 1, num_node_blocks = num_node_blocks,
             dist = Categorical(num_cats = num_emits)
         )
 
@@ -63,13 +63,13 @@ def HMM(seq_length: int, num_latents: int, num_emits: int, homogeneous: bool = T
             curr_xs = ns_input.duplicate(var, tie_params = homogeneous)
 
             if ns_sum is None:
-                ns = summate(curr_zs, num_node_groups = num_node_groups)
+                ns = summate(curr_zs, num_node_blocks = num_node_blocks)
                 ns_sum = ns
             else:
                 ns = ns_sum.duplicate(curr_zs, tie_params = homogeneous)
 
             curr_zs = multiply(curr_xs, ns)
 
-        ns = summate(curr_zs, num_node_groups = 1, group_size = 1)
+        ns = summate(curr_zs, num_node_blocks = 1, block_size = 1)
 
     return ns
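
For reference, a hypothetical usage sketch of the renamed API (the `HMM` signature follows the hunk context above, and `juice.compile` is used as in the tutorials; the argument values are illustrative only):

    import pyjuice as juice

    # Homogeneous HMM over length-32 sequences, 2048 latent states, 4023 emission symbols
    ns = juice.structures.HMM(seq_length = 32, num_latents = 2048, num_emits = 4023, homogeneous = True)
    pc = juice.compile(ns)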