diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4bb8d5f --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +venv/* +data/* +runs/* +runs_*/* +.idea/* +*.gv* +*/__pycache__/* +*.pyc +wandb/* diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..574add2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Serhii Kostiuk + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..fec6435 --- /dev/null +++ b/README.md @@ -0,0 +1,104 @@ +# Learnable Extended Activation Function (LEAF) for Deep Neural Networks + +Implementation of the experiment as published in the paper "Learnable Extended +Activation Function for Deep Neural Networks" by +Yevgeniy Bodyanskiy and Serhii Kostiuk. + +## Running experiments + +1. NVIDIA GPU recommended with at least 2 GiB of VRAM. +2. Install the requirements from `requirements.txt`. +3. Set `CUBLAS_WORKSPACE_CONFIG=:4096:8` in the environment variables. +4. Use the root of this repository as the current directory. +5. Add the current directory to `PYTHONPATH` so it can find the modules + +This repository contains a wrapper script that sets all the required +environment variables: [run_experiment.sh](./run_experiment.sh). Use the bash shell to +execute the experiment using the wrapper script: + +Example: + +```shell +user@host:~/repo_path$ ./run_experiment.sh experiments/train_new_base.py +``` + +## Reproducing the results from the paper + +1. Training LeNet-5 and KerasNet networks with linear units from scratch: + + ```shell + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_lus base + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_lus ahaf + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_lus ahaf --dspu4 + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_lus leaf --p24sl + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_lus leaf --p24sl --dspu4 + ``` + +2. 
Training LeNet-5 and KerasNet networks with bounded activation functions from scratch: + + ```shell + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_bfs base + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_bfs ahaf + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_bfs ahaf --dspu4 + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_bfs leaf --p24sl + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --opt adam --end_ep 100 --acts all_bfs leaf --p24sl --dspu4 + ``` + +3. On stability of LEAF-as-ReLU: + + ```shell + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --end_ep 100 --acts ReLU --net KerasNet --ds CIFAR-10 \ + --opt adam leaf + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --end_ep 100 --acts ReLU --net KerasNet --ds CIFAR-10 \ + --opt adam leaf --p24sl + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --end_ep 100 --acts ReLU --net KerasNet --ds CIFAR-10 \ + --opt rmsprop ahaf + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --end_ep 100 --acts ReLU --net KerasNet --ds CIFAR-10 \ + --opt rmsprop leaf + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --end_ep 100 --acts ReLU --net KerasNet --ds CIFAR-10 \ + --opt rmsprop leaf --p24sl + ``` + + Add the `--wandb` parameter to log the training process to Weights and + Biases. Weights and Biases provides visualization of the parameter values and + the gradient values during training. + +4. On the effect of synaptic weights initialization. Execute all commands below + once for each of the seed values: + + ```shell + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --seed 7823 --opt adam --ds CIFAR-10 base + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --seed 7823 --opt adam --ds CIFAR-10 ahaf + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --seed 7823 --opt adam --ds CIFAR-10 ahaf --dspu4 + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --seed 7823 --opt adam --ds CIFAR-10 leaf --p24sl + user@host:~/repo_path$ ./run_experiment.sh experiments/train_individual.py \ + --seed 7823 --opt adam --ds CIFAR-10 leaf --p24sl --dspu4 + ``` + + Seed values to evaluate: 42, 100, 128, 1999, 7823. + +## Visualization of experiment results + +Use tools from the [post_experiment](./post_experiment) directory to visualize +the training process, create the training result summary tables, and visualize +the activation function form for LEAF/AHAF compared to the corresponding base +activations.
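## Using the adaptive activations directly

The adaptive activation functions are regular `torch.nn.Module` objects and can
also be used outside the provided training scripts. The sketch below is only an
illustration (it is not part of the published experiments, and the parameter
shapes are arbitrary); the constructor arguments follow the definitions in the
`adaptive_afs` package:

```python
import torch

from adaptive_afs import AHAF, LEAF, af_build, AfDefinition

# One LEAF activation parameter set per channel of a 6-channel feature map,
# initialized to behave like SiLU before training.
leaf_act = LEAF(size=(6, 1, 1), init_as="SiLU")

# AHAF with a single shared (beta, gamma) pair, initialized as ReLU.
ahaf_act = AHAF(size=(1,), init_as="ReLU")

# Fuzzy activation (F-neuron) built through the generic builder,
# approximating Tanh on the interval [-4, 4] with 16 segments.
fuzzy_def = AfDefinition(
    af_base="Tanh", af_type=AfDefinition.AfType.ADA_FUZZ,
    af_interval=AfDefinition.AfInterval(start=-4.0, end=+4.0, n_segments=16)
)
fuzzy_act = af_build(fuzzy_def, in_dims=(6, 1, 1))

x = torch.randn(8, 6, 28, 28)
print(leaf_act(x).shape, ahaf_act(x).shape, fuzzy_act(x).shape)
```

The learnable parameters of the activations (`p1`-`p4` for LEAF, `beta` and
`gamma` for AHAF, the membership function weights for the fuzzy variant) are
exposed through `Module.parameters()` and are trained together with the
synaptic weights unless the `--dspu4` procedure or a parameter freezer is used.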
diff --git a/adaptive_afs/__init__.py b/adaptive_afs/__init__.py new file mode 100644 index 0000000..735c05e --- /dev/null +++ b/adaptive_afs/__init__.py @@ -0,0 +1,3 @@ +from .cont import AHAF, LEAF +from .fuzzy import FNeuronAct +from .af_builder import af_build, AfDefinition diff --git a/adaptive_afs/af_builder/__init__.py b/adaptive_afs/af_builder/__init__.py new file mode 100644 index 0000000..a751b50 --- /dev/null +++ b/adaptive_afs/af_builder/__init__.py @@ -0,0 +1,2 @@ +from .af_definition import AfDefinition +from .af_build_m import af_build diff --git a/adaptive_afs/af_builder/af_build_fuzzy.py b/adaptive_afs/af_builder/af_build_fuzzy.py new file mode 100644 index 0000000..df9ce70 --- /dev/null +++ b/adaptive_afs/af_builder/af_build_fuzzy.py @@ -0,0 +1,15 @@ +from typing import Tuple + +from ..fuzzy.f_neuron_act import FNeuronAct + + +def af_build_fuzzy( + af_base: str, af_start: float, af_end: float, n_segments: int, + in_dims: Tuple[int, ...] = (1,) +) -> FNeuronAct: + init_f = FNeuronAct.get_init_f_by_name(af_base) + + return FNeuronAct( + af_start, af_end, n_segments, + init_f=init_f, input_dim=in_dims + ) diff --git a/adaptive_afs/af_builder/af_build_m.py b/adaptive_afs/af_builder/af_build_m.py new file mode 100644 index 0000000..a5cdd2b --- /dev/null +++ b/adaptive_afs/af_builder/af_build_m.py @@ -0,0 +1,42 @@ +from typing import Union, Optional, Tuple + +from .af_definition import AfDefinition +from ..cont import AHAF, LEAF +from ..fuzzy import FNeuronAct +from .af_build_fuzzy import af_build_fuzzy +from .af_build_traditional import af_build_traditional, AfTraditional + +ActivationFunction = Union[ + AfTraditional, AHAF, LEAF, FNeuronAct +] + + +def af_build( + d: AfDefinition, in_dims: Optional[Tuple[int, ...]] = None +) -> ActivationFunction: + if in_dims is None: + # Has sense only for adaptive activations + in_dims = (1,) + + if d.af_type == AfDefinition.AfType.TRAD: + if d.interval is None: + return af_build_traditional(d.af_base) + else: + return af_build_traditional( + d.af_base, + d.interval.start, + d.interval.end + ) + elif d.af_type == AfDefinition.AfType.ADA_AHAF: + return AHAF(size=in_dims, init_as=d.af_base) + elif d.af_type == AfDefinition.AfType.ADA_LEAF: + return LEAF(size=in_dims, init_as=d.af_base) + elif d.af_type == AfDefinition.AfType.ADA_FUZZ: + return af_build_fuzzy( + d.af_base, + d.interval.start, d.interval.end, + d.interval.n_segments, + in_dims + ) + else: + raise NotImplementedError("The requested AF type is not supported") diff --git a/adaptive_afs/af_builder/af_build_traditional.py b/adaptive_afs/af_builder/af_build_traditional.py new file mode 100644 index 0000000..c5a8470 --- /dev/null +++ b/adaptive_afs/af_builder/af_build_traditional.py @@ -0,0 +1,38 @@ +from typing import Callable, Optional + +import torch +import torch.nn +import torch.nn.functional + +from torch import Tensor + +from ..trad import silu_manual + + +AfTraditional = Callable[[Tensor], Tensor] + + +def af_build_traditional( + af_name: str, + min_val: Optional[float] = None, max_val: Optional[float] = None +) -> AfTraditional: + if af_name == "ReLU": + return torch.relu + elif af_name == "SiLU": + # Using a custom SiLU implementation to exactly follow AAF alternatives + return silu_manual + elif af_name == "Tanh": + return torch.tanh + elif af_name == "HardTanh": + if min_val is None or max_val is None: + return torch.nn.Hardtanh() + else: + return torch.nn.Hardtanh(min_val, max_val) + elif af_name == "Sigmoid": + return torch.sigmoid + elif af_name == "HardSigmoid": + 
return torch.nn.functional.hardsigmoid + else: + raise NotImplementedError( + "The requested traditional activation function is not supported" + ) diff --git a/adaptive_afs/af_builder/af_definition.py b/adaptive_afs/af_builder/af_definition.py new file mode 100644 index 0000000..b3a8a6b --- /dev/null +++ b/adaptive_afs/af_builder/af_definition.py @@ -0,0 +1,24 @@ +from enum import Enum +from typing import Optional, Tuple + + +class AfDefinition: + class AfType(Enum): + TRAD = 0 + ADA_AHAF = 1 + ADA_FUZZ = 2 + ADA_LEAF = 3 + + class AfInterval: + def __init__(self, start: float, end: float, n_segments: int = 0): + self.start = start + self.end = end + self.n_segments = n_segments + + def __init__( + self, af_base: str = "ReLU", af_type: AfType = AfType.TRAD, + af_interval: Optional[AfInterval] = None + ): + self.af_base = af_base + self.af_type = af_type + self.interval = af_interval diff --git a/adaptive_afs/cont/__init__.py b/adaptive_afs/cont/__init__.py new file mode 100644 index 0000000..81d2ced --- /dev/null +++ b/adaptive_afs/cont/__init__.py @@ -0,0 +1,2 @@ +from .ahaf import AHAF +from .leaf import LEAF diff --git a/adaptive_afs/cont/ahaf.py b/adaptive_afs/cont/ahaf.py new file mode 100644 index 0000000..9ccbeb4 --- /dev/null +++ b/adaptive_afs/cont/ahaf.py @@ -0,0 +1,82 @@ +from typing import Tuple, Any, Sequence + +import torch +from torch.autograd.function import FunctionCtx + + +class _AHAF(torch.autograd.Function): + @staticmethod + def forward(u, beta, gamma) -> Any: + y = (beta * u) * torch.sigmoid(gamma * u) + return y + + @staticmethod + def setup_context(ctx: FunctionCtx, inputs: Sequence[Any], output: Any) -> None: + u, beta, gamma = inputs + ctx.save_for_backward(u, beta, gamma) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: Any) -> Any: + u, beta, gamma = ctx.saved_tensors + grad_u = grad_beta = grad_gamma = None + + gamma_u = gamma * u + sig_gamma_u = torch.sigmoid(gamma_u) + + if ctx.needs_input_grad[0]: + grad_u = grad_output.mul( + (beta * sig_gamma_u) + * + (1 + u * gamma * (1 - sig_gamma_u)) + ) + if ctx.needs_input_grad[1]: + grad_beta = grad_output.mul(u * sig_gamma_u) + if ctx.needs_input_grad[2]: + grad_gamma = grad_output.mul( + (beta * u) + * sig_gamma_u + * torch.sigmoid(-gamma_u) + * u + ) + + return grad_u, grad_beta, grad_gamma + + +def _ahaf(u, beta, gamma): + return _AHAF.apply(u, beta, gamma) + + +class AHAF(torch.nn.Module): + def __init__(self, *, size: Tuple[int, ...] 
= (1,), init_as: str = 'ReLU'): + super(AHAF, self).__init__() + + if init_as == 'ReLU': + self.gamma = torch.nn.Parameter(torch.ones(*size) * (2.0**16)) + self.beta = torch.nn.Parameter(torch.ones(*size)) + elif init_as == 'SiLU': + self.gamma = torch.nn.Parameter(torch.ones(*size)) + self.beta = torch.nn.Parameter(torch.ones(*size)) + elif init_as == 'CUSTOM': + self.gamma = torch.nn.Parameter(torch.ones(*size)*10) + self.beta = torch.nn.Parameter(torch.ones(*size)) + else: + raise ValueError("Invalid initialization mode [{}]".format(init_as)) + + @staticmethod + def _get_sample_value(t: torch.Tensor) -> float: + size = t.size() + + for _ in size: + t = t[0] + + return t.item() + + def forward(self, inputs): + return _ahaf(inputs, self.beta, self.gamma) + + def __repr__(self): + return "AHAF(size={},gamma={}, beta={})".format( + tuple(self.gamma.size()), + self._get_sample_value(self.gamma), + self._get_sample_value(self.beta) + ) diff --git a/adaptive_afs/cont/leaf.py b/adaptive_afs/cont/leaf.py new file mode 100644 index 0000000..24109ab --- /dev/null +++ b/adaptive_afs/cont/leaf.py @@ -0,0 +1,108 @@ +from typing import Tuple, Any, Sequence + +import torch +from torch.autograd.function import FunctionCtx + + +class _LEAF(torch.autograd.Function): + @staticmethod + def forward(u, p1, p2, p3, p4) -> Any: + y = (p1 * u + p2) * torch.sigmoid(p3 * u) + p4 + return y + + @staticmethod + def setup_context(ctx: FunctionCtx, inputs: Sequence[Any], output: Any) -> None: + u, p1, p2, p3, p4 = inputs + ctx.save_for_backward(u, p1, p2, p3, p4) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: Any) -> Any: + u, p1, p2, p3, p4 = ctx.saved_tensors + grad_u = grad_p1 = grad_p2 = grad_p3 = grad_p4 = None + + p3_u = p3 * u + sig_p3_u = torch.sigmoid(p3_u) + + if ctx.needs_input_grad[0]: + grad_u = grad_output.mul( + (p1 * sig_p3_u) + + + (p1 * u + p2) + * sig_p3_u + * torch.sigmoid(-p3_u) + * p3 + ) + if ctx.needs_input_grad[1]: + grad_p1 = grad_output.mul(u * sig_p3_u) + if ctx.needs_input_grad[2]: + grad_p2 = grad_output.mul(sig_p3_u) + if ctx.needs_input_grad[3]: + grad_p3 = grad_output.mul( + (p1 * u + p2) + * sig_p3_u + * torch.sigmoid(-p3_u) + * u + ) + if ctx.needs_input_grad[4]: + grad_p4 = grad_output.mul(torch.ones_like(u)) + + return grad_u, grad_p1, grad_p2, grad_p3, grad_p4 + + +def _leaf(u, p1, p2, p3, p4): + return _LEAF.apply(u, p1, p2, p3, p4) + + +class LEAF(torch.nn.Module): + def __init__(self, *, size: Tuple[int, ...] 
= (1,), init_as: str = 'ReLU'): + super(LEAF, self).__init__() + + if init_as == 'ReLU': + self.p1 = torch.nn.Parameter(torch.ones(*size)) + self.p2 = torch.nn.Parameter(torch.zeros(*size)) + self.p3 = torch.nn.Parameter(torch.ones(*size) * (2.0**16)) + self.p4 = torch.nn.Parameter(torch.zeros(*size)) + elif init_as == 'SiLU': + self.p1 = torch.nn.Parameter(torch.ones(*size)) + self.p2 = torch.nn.Parameter(torch.zeros(*size)) + self.p3 = torch.nn.Parameter(torch.ones(*size)) + self.p4 = torch.nn.Parameter(torch.zeros(*size)) + elif init_as == 'CUSTOM': + self.p1 = torch.nn.Parameter(torch.ones(*size)) + self.p2 = torch.nn.Parameter(torch.zeros(*size)) + self.p3 = torch.nn.Parameter(torch.ones(*size) * 10) + self.p4 = torch.nn.Parameter(torch.zeros(*size)) + elif init_as == 'Tanh': + self.p1 = torch.nn.Parameter(torch.zeros(*size)) + self.p2 = torch.nn.Parameter(torch.ones(*size) * 2.0) + self.p3 = torch.nn.Parameter(torch.ones(*size) * 2.0) + self.p4 = torch.nn.Parameter(torch.ones(*size) * -1.0) + elif init_as == 'Sigmoid': + self.p1 = torch.nn.Parameter(torch.zeros(*size)) + self.p2 = torch.nn.Parameter(torch.ones(*size)) + self.p3 = torch.nn.Parameter(torch.ones(*size)) + self.p4 = torch.nn.Parameter(torch.zeros(*size)) + else: + raise ValueError("Invalid initialization mode [{}]".format(init_as)) + + @staticmethod + def _get_sample_value(t: torch.Tensor) -> float: + size = t.size() + + for _ in size: + t = t[0] + + return t.item() + + def forward(self, x): + y = _leaf(x, self.p1, self.p2, self.p3, self.p4) + return y + + def __repr__(self): + return "LEAF(size={},p1={},p2={},p3={},p4={})".format( + tuple(self.p3.size()), + self._get_sample_value(self.p1), + self._get_sample_value(self.p2), + self._get_sample_value(self.p3), + self._get_sample_value(self.p4) + ) diff --git a/adaptive_afs/fuzzy/__init__.py b/adaptive_afs/fuzzy/__init__.py new file mode 100644 index 0000000..88709f0 --- /dev/null +++ b/adaptive_afs/fuzzy/__init__.py @@ -0,0 +1 @@ +from .f_neuron_act import FNeuronAct diff --git a/adaptive_afs/fuzzy/f_neuron_act.py b/adaptive_afs/fuzzy/f_neuron_act.py new file mode 100644 index 0000000..dd90161 --- /dev/null +++ b/adaptive_afs/fuzzy/f_neuron_act.py @@ -0,0 +1,314 @@ +from typing import Callable, Tuple + +import torch.nn +from torch.nn import functional as F + +from .fmf_triangular import TriangularMembF +from .fmf_ramp_left import LeftRampMembF +from .fmf_ramp_right import RightRampMembF + + +class FNeuronAct(torch.nn.Module): + @staticmethod + def ramp_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Initialize member function weights to create a ramp function + from -1.0 to +1.0. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered, + ignored by this implementation. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. + """ + low = -1.0 + high = +1.0 + range_ = high - low + step = range_ / (count + 1) + eps = step / 100 + sample = torch.arange(low, high + eps, step) + result = torch.empty(*input_dim, len(sample)) + return result.copy_(sample) + + @staticmethod + def inv_ramp_init( + count: int, + input_dim: Tuple[int, ...] 
= (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Initialize member function weights to create an inverse ramp function + from +1.0 to -1.0. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered, + ignored by this implementation. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. + """ + return - FNeuronAct.ramp_init(count, input_dim, in_range) + + @staticmethod + def random_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Random weights initialization, ranging from -1.0 to +1.0. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered, + ignored by this implementation. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. + """ + low = -1.0 + high = +1.0 + out_range = high - low + return low + torch.rand(*input_dim, count + 2) * out_range + + @staticmethod + def all_hot_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Constant weights initialization, all membership functions active with + the same weight of 1.0. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered, + ignored by this implementation. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. + """ + return torch.ones(*input_dim, count + 2) + + @staticmethod + def _init_as_function( + orig_function: Callable[[torch.Tensor], torch.Tensor], + count: int, + input_dim: Tuple[int, ...] = (1,), + left: float = -1.0, + right: float = +1.0 + ) -> torch.Tensor: + range_ = right - left + step = range_ / (count + 1) + eps = step / 100 + sample = torch.arange(left, right + eps, step) + sample = orig_function(sample) + result = torch.empty(*input_dim, len(sample)) + return result.copy_(sample) + + @staticmethod + def tanh_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Weights initialization that creates an activation function that roughly + corresponds to an approximated tanh function. + + NOTE: The minimum recommended input range is from -5.0 to +5.0 to + follow tanh with less than 10^-3 error outside the range. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. 
+ """ + left, right = in_range + return FNeuronAct._init_as_function( + torch.tanh, count, input_dim, left, right + ) + + @staticmethod + def sigmoid_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Weights initialization that creates an activation function that roughly + corresponds to an approximated sigmoid function. + + NOTE: The minimum recommended input range is from -10.0 to +10.0 to + follow sigmoid with less than 10^-3 error outside the range. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. + """ + left, right = in_range + return FNeuronAct._init_as_function( + torch.sigmoid, count, input_dim, left, right + ) + + @staticmethod + def hard_sigmoid_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Weights initialization that creates an activation function that exactly + corresponds to a hard sigmoid. + + NOTE: The minimum recommended input range is from -3.0 to +3.0, + because the hard sigmoid function is defined on exactly this range. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. + """ + left, right = in_range + return FNeuronAct._init_as_function( + F.hardsigmoid, count, input_dim, left, right + ) + + @staticmethod + def hard_tanh_init( + count: int, + input_dim: Tuple[int, ...] = (1,), + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + """ + Weights initialization that creates an activation function that exactly + corresponds to a hard tanh. + + NOTE: The minimum recommended input range is from -1.0 to +1.0, + because by default the hard tanh function is defined on exactly this + range in PyTorch. + + :param count: number of member functions. + :param input_dim: input data dimensions: + - scalar ``(1,)`` - by default; + - vector ``(x,)``; + - matrix ``(x,y)``; + - multi-channel image: ``(z,x,y)`` where ``z`` is the number of + channels. + :param in_range: the range of input values that should be covered. + :return: initialized tensor of size ``(z,x,y,count)`` where dimensions + ``z`` and ``y`` are optional. 
+ """ + left, right = in_range + return FNeuronAct._init_as_function( + F.hardsigmoid, count, input_dim, left, right + ) + + @classmethod + def get_init_f_by_name( + cls, init_f_name: str + ) -> Callable[[int, Tuple[int, ...], Tuple[float, float]], torch.Tensor]: + if init_f_name == "Ramp": + fuzzy_init_f = FNeuronAct.ramp_init + elif init_f_name == "Random": + fuzzy_init_f = FNeuronAct.random_init + elif init_f_name == "Constant": + fuzzy_init_f = FNeuronAct.all_hot_init + elif init_f_name == "Tanh": + fuzzy_init_f = FNeuronAct.tanh_init + elif init_f_name == "Sigmoid": + fuzzy_init_f = FNeuronAct.sigmoid_init + elif init_f_name == "HardSigmoid": + fuzzy_init_f = FNeuronAct.hard_sigmoid_init + elif init_f_name == "HardTanh": + fuzzy_init_f = FNeuronAct.hard_tanh_init + else: + raise NotImplementedError( + "Other initialization functions for fuzzy weights are not " + "supported." + ) + + return fuzzy_init_f + + def __init__( + self, left: float, right: float, count: int, + *, init_f: Callable[[int, Tuple[int, ...]], torch.Tensor] = None, + input_dim=(1,) + ): + super().__init__() + self._mfs = torch.nn.ModuleList() + + assert left < right + assert count >= 1 + + if init_f is None: + init_f = self.ramp_init + + self._weights = torch.nn.Parameter( + init_f(count, input_dim, (left, right)) + ) + self._mf_radius = (right - left) / (count + 1) + + self._mfs.append( + LeftRampMembF(self._mf_radius, left) + ) + + for i in range(1, count + 1): + mf_center = left + self._mf_radius * i + mf = TriangularMembF(self._mf_radius, mf_center) + + self._mfs.append(mf) + + self._mfs.append( + RightRampMembF(self._mf_radius, right) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = [mf.forward(x) for mf in self._mfs] + x = torch.stack(x, -1) + x = torch.mul(x, self._weights) + x = torch.sum(x, -1) + return x diff --git a/adaptive_afs/fuzzy/fmf_ramp_left.py b/adaptive_afs/fuzzy/fmf_ramp_left.py new file mode 100644 index 0000000..2f9873d --- /dev/null +++ b/adaptive_afs/fuzzy/fmf_ramp_left.py @@ -0,0 +1,34 @@ +import torch.nn + + +class LeftRampMembF(torch.nn.Module): + """ + ----- + \\ + \\ + \\_______ + """ + def __init__(self, radius: float, center: float): + super().__init__() + + self._radius = radius + self._center = center + self._right = center + radius + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Move to the center + x = x - self._center + + # Scale to the -1.0..1.0 range + x = x / self._radius + + # Clip the value to create a ramp + x = torch.clip(x, min=0.0, max=1.0) + + # Invert the value to create a left ramp + x = 1 - x + + return x + + def __repr__(self): + return "left: {},{}".format(self._center, self._right) diff --git a/adaptive_afs/fuzzy/fmf_ramp_right.py b/adaptive_afs/fuzzy/fmf_ramp_right.py new file mode 100644 index 0000000..e85ba29 --- /dev/null +++ b/adaptive_afs/fuzzy/fmf_ramp_right.py @@ -0,0 +1,34 @@ +import torch.nn + + +class RightRampMembF(torch.nn.Module): + """ + ______ + / + / + _____/ + """ + def __init__(self, radius: float, center: float): + super().__init__() + + self._radius = radius + self._center = center + self._left = center - radius + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Move to the center + x = x - self._center + + # Scale to the -1.0..1.0 range + x = x / self._radius + + # Clip the value to create a ramp + x = torch.clip(x, min=-1.0, max=0.0) + + # Invert the value to create a right ramp + x = 1 + x + + return x + + def __repr__(self): + return "right: {},{}".format(self._left, self._center) diff --git 
a/adaptive_afs/fuzzy/fmf_triangular.py b/adaptive_afs/fuzzy/fmf_triangular.py new file mode 100644 index 0000000..e289873 --- /dev/null +++ b/adaptive_afs/fuzzy/fmf_triangular.py @@ -0,0 +1,32 @@ +import torch.nn + + +class TriangularMembF(torch.nn.Module): + def __init__(self, radius: float, center: float): + super().__init__() + + self._radius = radius + self._center = center + self._left = center - radius + self._right = center + radius + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Center the inputs + x = x - self._center + + # Scale inputs to the -1.0..+1.0 range + x = x / self._radius + + # Use absolute values + x = torch.absolute(x) + + # Compute output: y = 1 - abs(x) + x = 1 - x + + # Drop outliers (drop negative values after all the operations above) + x = torch.clip(x, 0.0) + + return x + + def __repr__(self): + return "triangle: {},{},{}".format(self._left, self._center, self._right) diff --git a/adaptive_afs/trad/__init__.py b/adaptive_afs/trad/__init__.py new file mode 100644 index 0000000..77574ca --- /dev/null +++ b/adaptive_afs/trad/__init__.py @@ -0,0 +1,2 @@ +from .silu_manual import silu_manual +from .tanh_manual import tanh_manual diff --git a/adaptive_afs/trad/silu_manual.py b/adaptive_afs/trad/silu_manual.py new file mode 100644 index 0000000..cf05562 --- /dev/null +++ b/adaptive_afs/trad/silu_manual.py @@ -0,0 +1,5 @@ +import torch + + +def silu_manual(x): + return x * torch.sigmoid(x) diff --git a/adaptive_afs/trad/tanh_manual.py b/adaptive_afs/trad/tanh_manual.py new file mode 100644 index 0000000..9d5b020 --- /dev/null +++ b/adaptive_afs/trad/tanh_manual.py @@ -0,0 +1,9 @@ +import torch + + +def tanh_manual(x): + return ( + (torch.exp(x) - torch.exp(-x)) + / + (torch.exp(x) + torch.exp(-x)) + ) diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/common.py b/experiments/common.py new file mode 100644 index 0000000..e1eadf2 --- /dev/null +++ b/experiments/common.py @@ -0,0 +1,103 @@ +import torch.utils.data +import torchvision + +from typing import Tuple + + +def get_device(dev_name: str = 'gpu') -> torch.device: + if dev_name == 'gpu' and torch.cuda.is_available(): + print("Using GPU computing unit") + torch.cuda.set_device(0) + device = torch.device('cuda:0') + print("Cuda computing capability: {}.{}".format( + *torch.cuda.get_device_capability(device) + )) + else: + print("Using CPU computing unit") + device = torch.device('cpu') + + return device + + +def get_cifar10_dataset( + augment: bool = False +) -> Tuple[torch.utils.data.Dataset, ...]: + if augment: + augments = ( + # as in Keras - each second image is flipped + torchvision.transforms.RandomHorizontalFlip(p=0.5), + # assuming that the values from git.io/JuHV0 were used in + # arXiv 1801.09403 + torchvision.transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)) + ) + else: + augments = () + + train_set = torchvision.datasets.CIFAR10( + root="./data/cifar", + train=True, + download=True, + transform=torchvision.transforms.Compose( + (torchvision.transforms.ToTensor(), *augments) + ) + ) + + test_set = torchvision.datasets.CIFAR10( + root="./data/cifar", + train=False, + download=True, + transform=torchvision.transforms.Compose( + (torchvision.transforms.ToTensor(),) + ) + ) + + return train_set, test_set + + +def get_fmnist_dataset( + augment: bool = False +) -> Tuple[torch.utils.data.Dataset, ...]: + + if augment: + augments = ( + # as in Keras - each second image is flipped + 
torchvision.transforms.RandomHorizontalFlip(p=0.5), + # assuming that the values from git.io/JuHV0 were used in + # arXiv 1801.09403 + torchvision.transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)) + ) + else: + augments = () + + train_set = torchvision.datasets.FashionMNIST( + root="./data/FashionMNIST", + train=True, + download=True, + transform=torchvision.transforms.Compose( + (torchvision.transforms.ToTensor(), *augments) + ) + ) + + test_set = torchvision.datasets.FashionMNIST( + root="./data/FashionMNIST", + train=False, + download=True, + transform=torchvision.transforms.Compose( + (torchvision.transforms.ToTensor(),) + ) + ) + + return train_set, test_set + + +def get_dataset( + ds_name: str, augment: bool = False +) -> Tuple[torch.utils.data.Dataset, ...]: + if ds_name == "CIFAR-10": + return get_cifar10_dataset(augment) + elif ds_name == "F-MNIST": + return get_fmnist_dataset(augment) + else: + raise NotImplementedError( + "Datasets other than CIFAR-10 and F-MNIST are not supported" + ) diff --git a/experiments/eval_pretrained_base.py b/experiments/eval_pretrained_base.py new file mode 100644 index 0000000..f01164e --- /dev/null +++ b/experiments/eval_pretrained_base.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 + +from eval_pretrained_common import eval_variant + + +def main(): + af_names = ("ReLU", "SiLU") + for af in af_names: + eval_variant("KerasNet", "base", "CIFAR-10", af_name=af, start_ep=100) + + +if __name__ == "__main__": + main() diff --git a/experiments/eval_pretrained_common.py b/experiments/eval_pretrained_common.py new file mode 100644 index 0000000..9b50743 --- /dev/null +++ b/experiments/eval_pretrained_common.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn +import torch.utils.data +import torchinfo + +from experiments.common import get_device, get_dataset +from misc import get_file_name_net, create_net + + +def eval_variant( + net_name: str, net_type: str, ds_name: str, af_name: str, start_ep: int, + *, patched: bool = False +): + batch_size = 64 + rand_seed = 42 + + print( + "Loading pre-trained {} {} network with {}{} activation " + "on the {} dataset after {} epochs.".format( + net_type, net_name, af_name, "" if net_type == "base" else "-like", + ds_name, start_ep + ) + ) + + path_net = get_file_name_net( + net_name, net_type, ds_name, af_name, start_ep, patched + ) + + dev = get_device() + torch.manual_seed(rand_seed) + + _, test_set = get_dataset(ds_name, augment=True) + input_size = (batch_size, 3, 32, 32) + + test_loader = torch.utils.data.DataLoader( + test_set, batch_size=1000, num_workers=4 + ) + + net = create_net(net_name, net_type, ds_name, af_name) + net.to(device=dev) + torchinfo.summary(net, input_size=input_size, device=dev) + + missing, unexpected = net.load_state_dict( + torch.load(path_net), strict=True + ) + + print("Missing keys:", missing) + print("Unexpected keys:", unexpected) + + with torch.no_grad(): + net.eval() + test_total = 0 + test_correct = 0 + + for batch in test_loader: + x = batch[0].to(dev) + y = batch[1].to(dev) + y_hat = net(x) + _, pred = torch.max(y_hat.data, 1) + test_total += y.size(0) + test_correct += torch.eq(pred, y).sum().item() + + print("Epoch: {}. 
Test set accuracy: {:.2%}".format( + start_ep, test_correct / test_total + )) + diff --git a/experiments/eval_pretrained_patched_ahaf.py b/experiments/eval_pretrained_patched_ahaf.py new file mode 100644 index 0000000..0ea3369 --- /dev/null +++ b/experiments/eval_pretrained_patched_ahaf.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +from eval_pretrained_common import eval_variant + + +def main(): + af_names = ("ReLU", "SiLU") + net_name = "KerasNet" + ds_name = "CIFAR-10" + + for af in af_names: + eval_variant( + net_name, "ahaf", ds_name, af_name=af, start_ep=100, + patched=True + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/eval_pretrained_patched_fuzzy.py b/experiments/eval_pretrained_patched_fuzzy.py new file mode 100644 index 0000000..d89a830 --- /dev/null +++ b/experiments/eval_pretrained_patched_fuzzy.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +from eval_pretrained_common import eval_variant + + +def main(): + af_names = ("Tanh", "Sigmoid") + net_name = "KerasNet" + ds_name = "CIFAR-10" + + for af in af_names: + eval_variant( + net_name, "fuzzy_ffn", ds_name, af_name=af, start_ep=100, + patched=True + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/patch_pretrained_ahaf.py b/experiments/patch_pretrained_ahaf.py new file mode 100644 index 0000000..d164b81 --- /dev/null +++ b/experiments/patch_pretrained_ahaf.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn +import torch.utils.data +import torchinfo + +from experiments.common import get_device +from misc import get_file_name_net, create_net + + +def patch_variant_as_ahaf(net_name: str, ds_name: str, af_name: str): + batch_size = 64 + rand_seed = 42 + n_epochs_init = 100 + + print("Patching the base {} network with {} activation " + "for {} to use AHAF".format(net_name, af_name, ds_name)) + + path_base = get_file_name_net( + net_name, "base", ds_name, af_name, n_epochs_init, patched=False + ) + + dev = get_device() + torch.manual_seed(rand_seed) + input_size = (batch_size, 3, 32, 32) + + net = create_net(net_name, "ahaf", ds_name, af_name) + net.to(device=dev) + torchinfo.summary(net, input_size=input_size, device=dev) + + missing, unexpected = net.load_state_dict( + torch.load(path_base), strict=False + ) + + print("Missing keys:", missing) + print("Unexpected keys:", unexpected) + + path_ahaf = get_file_name_net( + net_name, "ahaf", ds_name, af_name, n_epochs_init, patched=True + ) + + torch.save( + net.state_dict(), + path_ahaf + ) + + +def main(): + af_names = ("ReLU", "SiLU") + net_name = "KerasNet" + ds_name = "CIFAR-10" + + for af in af_names: + patch_variant_as_ahaf(net_name, ds_name, af) + + +if __name__ == "__main__": + main() diff --git a/experiments/patch_pretrained_fuzzy.py b/experiments/patch_pretrained_fuzzy.py new file mode 100644 index 0000000..b46884d --- /dev/null +++ b/experiments/patch_pretrained_fuzzy.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn +import torch.utils.data +import torchinfo + +from experiments.common import get_device +from misc import get_file_name_net, create_net + + +def patch_variant_as_fuzzy(net_name: str, ds_name: str, af_name: str): + batch_size = 64 + rand_seed = 42 + n_epochs_init = 100 + + print("Patching the base {} network with {} activation " + "for {} to use Fuzzy AF".format(net_name, af_name, ds_name)) + + path_base = get_file_name_net( + net_name, "base", ds_name, af_name, n_epochs_init, patched=False + ) + + dev = get_device() + torch.manual_seed(rand_seed) + input_size = 
(batch_size, 3, 32, 32) + + net = create_net(net_name, "fuzzy_ffn", ds_name, af_name) + net.to(device=dev) + torchinfo.summary(net, input_size=input_size, device=dev) + + missing, unexpected = net.load_state_dict( + torch.load(path_base), strict=False + ) + + print("Missing keys:", missing) + print("Unexpected keys:", unexpected) + + path_ahaf = get_file_name_net( + net_name, "fuzzy_ffn", ds_name, af_name, n_epochs_init, patched=True + ) + + torch.save( + net.state_dict(), + path_ahaf + ) + + +def main(): + af_names = ("Tanh", "Sigmoid") + net_name = "KerasNet" + ds_name = "CIFAR-10" + + for af in af_names: + patch_variant_as_fuzzy(net_name, ds_name, af) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_common.py b/experiments/train_common.py new file mode 100644 index 0000000..c1534ad --- /dev/null +++ b/experiments/train_common.py @@ -0,0 +1,413 @@ +import json +import warnings + +from typing import Optional, Callable, List, Union, Iterable, TypedDict, Any, \ + Dict + +import torch +import torch.nn +import torch.utils.data +import torchinfo + +try: + import wandb + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + +from adaptive_afs import LEAF +from experiments.common import get_device, get_dataset +from misc import get_file_name_checkp, get_file_name_stat,\ + get_file_name_train_args + +from nns_aaf import KerasNetAaf, LeNetAaf +from misc import RunningStat, ProgressRecorder, create_net + + +AafNetwork = Union[KerasNetAaf, LeNetAaf] + + +class CheckPoint(TypedDict): + net: Dict[str, Any] + opts: List[Dict[str, Any]] + + +def _net_train_aaf(net: AafNetwork): + for p in net.parameters(): + p.requires_grad = False + + for p in net.activation_params: + p.requires_grad = True + + +def _net_train_non_aaf(net: AafNetwork): + for p in net.parameters(): + p.requires_grad = True + + for p in net.activation_params: + p.requires_grad = False + + +def _net_train_noop(net: AafNetwork): + pass + + +def _net_split_leaf_params(net: AafNetwork): + aaf_param_ids = set() + leaf_p24_params = [] + aaf_rest_params = [] + generic_params = [] + + for act in net.activations: + if isinstance(act, LEAF): + p24_params = (act.p2, act.p4) + p13_params = (act.p1, act.p3) + + leaf_p24_params.extend(p24_params) + aaf_rest_params.extend(p13_params) + aaf_param_ids.update( + (id(p) for p in act.parameters()) + ) + elif isinstance(act, torch.nn.Module): + aaf_rest_params.extend(act.parameters()) + aaf_param_ids.update( + (id(p) for p in act.parameters()) + ) + + for p in net.parameters(): + if id(p) not in aaf_param_ids: + generic_params.append(p) + + return leaf_p24_params, aaf_rest_params, generic_params + + +def get_opt_by_name( + opt_name: str, base_lr: float, + net_params: Iterable[Union[torch.nn.Parameter, Dict]] +) -> torch.optim.Optimizer: + if opt_name == 'rmsprop': + opt = torch.optim.RMSprop( + params=net_params, + lr=base_lr, + alpha=0.9, # default Keras + momentum=0.0, # default Keras + eps=1e-7, # default Keras + centered=False # default Keras + ) + elif opt_name == 'adam': + opt = torch.optim.Adam( + params=net_params, + lr=base_lr, + ) + else: + raise NotImplementedError("Only ADAM and RMSProp supported") + + return opt + + +def train_variant( + net_name: str, net_type: str, + ds_name: str, af_name: str, end_epoch: int = 100, *, + start_epoch: int = 0, patched: bool = False, + af_name_cnn: Optional[str] = None, + param_freezer: Optional[Callable[[AafNetwork], None]] = None, + save_as_fine_tuned: bool = False, + dspu4: bool = False, p24sl: bool = False, opt_name: 
str = 'rmsprop', + seed: int = 42, bs: int = 64, dev_name: str = 'gpu', + patch_base: bool = False, wandb_enable: bool = False +): + """ + Initializes, loads and trains the model for the specified number of epochs. + Saves the trained network, the optimizer state and the statistics to the + `./runs` directory. + + :param net_name: name of the model - "KerasNet" or "LeNet-5" + :param net_type: type of the model - "base", "ahaf", "leaf", "fuzzy_ffn" + :param ds_name: name of the dataset - "CIFAR-10" or "F-MNIST" + :param af_name: the initial activation function form name - + "ReLU", "SiLU", "Tanh", "Sigmoid" and so on + :param end_epoch: stop the training at this epoch + :param start_epoch: start the training at this epoch + :param patched: set to `True` to load the "patched" model that was initially + trained with the base activation and then upgraded to an + adaptive alternative + :param af_name_cnn: specify a different initial activation function form + for the convolutional layers of the network + :param param_freezer: a callback function to freeze some parameters in the + network before starting the training + :param save_as_fine_tuned: saves the files with the "tuned_" suffix + :param dspu4: set to `True` to use the 2SPU-4 training procedure + :param p24sl: set to `True` to decrease the LR for LEAF params p2 and p4 + :param opt_name: set the optimizer: 'adam' or 'rmsprop' + :param seed: the initial value for the RNG + :param bs: the training batch size + :param dev_name: the device to train on: 'gpu' or 'cpu' + :param patch_base: perform in-place patching of the base network + :param wandb_enable: enable logging to Weights and Biases + :return: None + """ + if param_freezer and dspu4: + raise ValueError( + "The parameter freezing function and the 2SPU-4 procedure can't be " + "enabled and used at the same time" + ) + + if wandb_enable and not WANDB_AVAILABLE: + raise ValueError( + "The wandb library is not available.
Install the library or disable" + "logging to Weights and Biases in the arguments" + ) + + batch_size = bs + rand_seed = seed + + dev = get_device(dev_name) + torch.manual_seed(rand_seed) + torch.use_deterministic_algorithms(mode=True) + + train_set, test_set = get_dataset(ds_name, augment=True) + input_size = (batch_size, *train_set[0][0].shape) + + train_loader = torch.utils.data.DataLoader( + train_set, batch_size=batch_size, num_workers=4 + ) + test_loader = torch.utils.data.DataLoader( + test_set, batch_size=1000, num_workers=4 + ) + + net = create_net( + net_name, net_type, ds_name, af_name, af_name_cnn=af_name_cnn + ) + + error_fn = torch.nn.CrossEntropyLoss() + net.to(device=dev) + torchinfo.summary(net, input_size=input_size, device=dev) + + if opt_name == 'rmsprop': + base_lr = 1e-4 + p24lr = base_lr / 10 + elif opt_name == 'adam': + base_lr = 0.001 + p24lr = base_lr / 1000 + else: + raise NotImplementedError("Only ADAM and RMSProp supported") + + if not net_type.startswith("leaf"): + # Ignore on everything except LEAF + p24sl = False + + net_params_leaf_p24, net_params_aaf_rest, net_params_non_aaf = _net_split_leaf_params(net) + opt_params_non_aaf = [ + {'params': net_params_non_aaf} + ] + + if p24sl: + print(f"Using a custom learning rate of {p24lr} for LEAF params " + f"p2 and p4") + opt_params_aaf = [ + {'params': net_params_aaf_rest}, + {'params': net_params_leaf_p24, 'lr': p24lr} + ] + else: + opt_params_aaf = [ + {'params': [*net_params_aaf_rest, *net_params_leaf_p24]} + ] + + opt_params_sets: List[List[Dict]] + + if dspu4: + # Create two different optimizers: one for AAF, one for non-AAF params + opt_params_sets = [ + opt_params_aaf, net_params_non_aaf + ] + else: + # Create a single optimizer for AAF and non-AAF params + opt_params_sets = [ + [*opt_params_non_aaf, *opt_params_aaf] + ] + + opts: List[torch.optim.Optimizer] + opts = [get_opt_by_name(opt_name, base_lr, ps) for ps in opt_params_sets] + + if start_epoch > 0: + net_type_to_load = "base" if patch_base else net_type + strict_load = not patch_base + dspu4_to_load = False if patch_base else dspu4 + p24sl_to_load = False if patch_base else p24sl + + path_checkp = get_file_name_checkp( + net_name, net_type_to_load, ds_name, af_name, start_epoch, patched, + af_name_cnn=af_name_cnn, dspu4=dspu4_to_load, + p24sl=p24sl_to_load, opt_name=opt_name + ) + + checkp: CheckPoint + checkp = torch.load(path_checkp) + net.load_state_dict(checkp['net'], strict=strict_load) + + if ('opts' in checkp) and (not patch_base): + opt_states = checkp['opts'] + assert len(opts) == len(opt_states) + for i in range(len(opt_states)): + opts[i].load_state_dict(opt_states[i]) + else: + warnings.warn( + "The old optimizer state is not available{}. Initialized the " + "optimizer from scratch.".format( + " after patching" if (patched or patch_base) else "" + ) + ) + + print( + "Training the {} {} network with {} in CNN and {} in FFN " + "on the {} dataset for {} epochs total using the {} training procedure " + "and the {} optimizer." + "".format( + net_type, net_name, af_name if af_name_cnn is None else af_name_cnn, + af_name, ds_name, end_epoch, "2SPU-4" if dspu4 else "standard", + opt_name + ) + ) + + # Freeze the parameters if such hook is defined. 
+ if param_freezer is not None: + param_freezer(net) + + mb_param_freezers = [] # type: List[Callable[[AafNetwork], None]] + if dspu4: + mb_param_freezers.append(_net_train_aaf) + mb_param_freezers.append(_net_train_non_aaf) + else: + mb_param_freezers.append(_net_train_noop) + assert len(opts) == len(mb_param_freezers) + + progress = ProgressRecorder() + + # TODO: Refactor, pass TypedDict as the function argument + args_content = { + "net_name": net_name, + "net_type": net_type, + "ds_name": ds_name, + "af_name": af_name, + "end_epoch": end_epoch, + "start_epoch": start_epoch, + "patched": patched, + "af_name_cnn": af_name_cnn, + #"param_freezer": param_freezer, + "save_as_fine_tuned": save_as_fine_tuned, + "dspu4": dspu4, + "p24sl": p24sl, + "opt_name": opt_name, + "seed": seed, + "bs": bs, + "dev_name": dev_name, + "patch_base": patch_base + } + + args_path = get_file_name_train_args( + net_name, net_type, ds_name, af_name, end_epoch, + patched or patch_base, + fine_tuned=save_as_fine_tuned, af_name_cnn=af_name_cnn, + dspu4=dspu4, p24sl=p24sl, opt_name=opt_name + ) + + if wandb_enable: + wandb_run_name = args_path.lstrip("runs/") + wandb_run_name = wandb_run_name.rstrip("_args.json") + wandb.init( + project='leaf-cnn', reinit=True, name=wandb_run_name, + config=args_content, group=f"{net_name}_{ds_name}_{af_name}" + ) + wandb.watch(net, criterion=error_fn, log_freq=390, log='all') + + with open(args_path, 'w') as f: + json.dump(args_content, f, indent=2) + + for epoch in range(start_epoch, end_epoch): + net.train() + loss_stat = RunningStat() + progress.start_ep() + + for mb in train_loader: + x, y = mb[0].to(dev), mb[1].to(dev) + last_loss_in_mb: float = -1.0 + + for mbf, opt in zip(mb_param_freezers, opts): + mbf(net) + + # The wandb logger does not support `net.forward()` + y_hat = net(x) + loss = error_fn(y_hat, target=y) + last_loss_in_mb = loss.item() + + # Update parameters + opt.zero_grad() + loss.backward() + opt.step() + + loss_stat.push(last_loss_in_mb) + + progress.end_ep() + net.eval() + + with torch.no_grad(): + test_total = 0 + test_correct = 0 + + for batch in test_loader: + x = batch[0].to(dev) + y = batch[1].to(dev) + y_hat = net(x) + test_loss = error_fn(y_hat, target=y) + _, pred = torch.max(y_hat.data, 1) + test_total += y.size(0) + test_correct += torch.eq(pred, y).sum().item() + + test_acc = test_correct / test_total + + print("Train set loss stat: m={}, var={}".format( + loss_stat.mean, loss_stat.variance + )) + print("Epoch: {}. Test set accuracy: {:.2%}. 
Test set loss: {:.2}".format( + epoch, test_acc, test_loss + )) + if wandb_enable: + wandb.log({ + 'train_loss': loss_stat.mean, + 'test_loss': test_loss, + 'test_acc': test_acc} + ) + progress.push_ep( + epoch, loss_stat.mean, loss_stat.variance, test_acc, + lr=' '.join( + [str(pg["lr"]) for opt in opts for pg in opt.param_groups] + ) + ) + + progress.save_as_csv( + get_file_name_stat( + net_name, net_type, ds_name, af_name, end_epoch, + patched or patch_base, + fine_tuned=save_as_fine_tuned, af_name_cnn=af_name_cnn, + dspu4=dspu4, p24sl=p24sl, opt_name=opt_name + ) + ) + + checkp: CheckPoint + checkp = { + 'net': net.state_dict(), + 'opts': [opt.state_dict() for opt in opts] + } + + torch.save( + checkp, + get_file_name_checkp( + net_name, net_type, ds_name, af_name, end_epoch, + patched or patch_base, + fine_tuned=save_as_fine_tuned, af_name_cnn=af_name_cnn, + dspu4=dspu4, p24sl=p24sl, opt_name=opt_name + ) + ) + + if wandb_enable: + wandb.finish() diff --git a/experiments/train_individual.py b/experiments/train_individual.py new file mode 100644 index 0000000..425d03e --- /dev/null +++ b/experiments/train_individual.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +import itertools +import argparse +import pprint + +from train_common import train_variant +from misc import NetInfo + + +def main(): + parser = argparse.ArgumentParser( + prog='train_new_simple' + ) + parser.add_argument('af_type') + parser.add_argument('--opt', default='rmsprop') + parser.add_argument('--seed', default=42) + parser.add_argument('--bs', default=128, type=int) + parser.add_argument('--p24sl', action='store_true') + parser.add_argument('--dspu4', action='store_true') + parser.add_argument('--dev', default='gpu') + parser.add_argument('--net', default='all') + parser.add_argument('--ds', default='all') + parser.add_argument('--start_ep', default=0, type=int) + parser.add_argument('--end_ep', default=100, type=int) + parser.add_argument('--patch_base', action='store_true') + parser.add_argument('--acts', default='all_lus') + parser.add_argument('--wandb', action='store_true') + args = parser.parse_args() + + net_names = ["LeNet-5", "KerasNet"] if args.net == 'all' else [args.net] + ds_names = ["F-MNIST", "CIFAR-10"] if args.ds == 'all' else [args.ds] + + if args.acts == 'all': + act_names = ['ReLU', 'SiLU', 'Tanh', 'Sigmoid'] + elif args.acts == 'all_lus': + act_names = ['ReLU', 'SiLU'] + elif args.acts == 'all_bfs': + act_names = ['Tanh', 'Sigmoid'] + else: + act_names = [args.acts] + + net_ds_combinations = itertools.product(net_names, ds_names, act_names) + start_ep = args.start_ep + end_ep = args.end_ep + nets = [] + opt = args.opt + seed = args.seed + bs = args.bs + p24sl = args.p24sl + dspu4 = args.dspu4 + dev_name = args.dev + patch_base = args.patch_base + wandb = args.wandb + + for n, ds, act in net_ds_combinations: + net = NetInfo(n, args.af_type, ds, act, end_ep, dspu4=dspu4, opt_name=opt, p24sl=p24sl) + nets.append(net) + + print("Training the following combinations:") + pprint.pprint(nets) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4, p24sl=net.p24sl, + opt_name=net.opt_name, seed=seed, bs=bs, dev_name=dev_name, + start_epoch=start_ep, patch_base=patch_base, wandb_enable=wandb + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_ahaf.py b/experiments/train_new_ahaf.py new file mode 100644 index 0000000..16f7e73 --- /dev/null +++ b/experiments/train_new_ahaf.py @@ -0,0 +1,30 @@ 
+#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + af_names = ["ReLU", "SiLU"] + combinations = itertools.product(net_names, ds_names, af_names) + epochs = 100 + nets = [] + + for n, ds, af in combinations: + nets_nds = [ + NetInfo(n, "ahaf", ds, af, epochs, dspu4=False), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4 + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_all_bfs_ffn.py b/experiments/train_new_all_bfs_ffn.py new file mode 100644 index 0000000..101f1a2 --- /dev/null +++ b/experiments/train_new_all_bfs_ffn.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + net_ds_combinations = itertools.product(net_names, ds_names) + epochs = 100 + nets = [] + + for n, ds in net_ds_combinations: + nets_nds = [ + NetInfo(n, "base", ds, "Tanh", epochs, dspu4=False), + NetInfo(n, "leaf", ds, "Tanh", epochs, dspu4=False), + NetInfo(n, "leaf", ds, "Tanh", epochs, dspu4=True), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4 + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_all_lus.py b/experiments/train_new_all_lus.py new file mode 100644 index 0000000..5cea90b --- /dev/null +++ b/experiments/train_new_all_lus.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + net_ds_combinations = itertools.product(net_names, ds_names) + epochs = 100 + nets = [] + opt = "rmsprop" + #opt = "adam" + + for n, ds in net_ds_combinations: + nets_nds = [ + NetInfo(n, "base", ds, "ReLU", epochs, dspu4=False, opt_name=opt), + NetInfo(n, "base", ds, "SiLU", epochs, dspu4=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "ReLU", epochs, dspu4=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", epochs, dspu4=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "ReLU", epochs, dspu4=True, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", epochs, dspu4=True, opt_name=opt), + NetInfo(n, "leaf", ds, "ReLU", epochs, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", epochs, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "ReLU", epochs, dspu4=True, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", epochs, dspu4=True, p24sl=True, opt_name=opt), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4, p24sl=net.p24sl, + opt_name=net.opt_name + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_all_lus_ffn.py b/experiments/train_new_all_lus_ffn.py new file mode 100644 index 0000000..bbd143f --- /dev/null +++ b/experiments/train_new_all_lus_ffn.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + af_names 
= ["ReLU", "SiLU"] + combinations = itertools.product(net_names, ds_names, af_names) + epochs = 100 + nets = [] + + for n, ds, af in combinations: + nets_nds = [ + NetInfo(n, "leaf_ffn", ds, af, epochs, dspu4=False), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4 + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_all_no_fuzzy_bfs.py b/experiments/train_new_all_no_fuzzy_bfs.py new file mode 100644 index 0000000..101f1a2 --- /dev/null +++ b/experiments/train_new_all_no_fuzzy_bfs.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + net_ds_combinations = itertools.product(net_names, ds_names) + epochs = 100 + nets = [] + + for n, ds in net_ds_combinations: + nets_nds = [ + NetInfo(n, "base", ds, "Tanh", epochs, dspu4=False), + NetInfo(n, "leaf", ds, "Tanh", epochs, dspu4=False), + NetInfo(n, "leaf", ds, "Tanh", epochs, dspu4=True), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4 + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_base.py b/experiments/train_new_base.py new file mode 100644 index 0000000..0a0b196 --- /dev/null +++ b/experiments/train_new_base.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + af_names = ["ReLU", "SiLU", "Tanh", "Sigmoid"] + combinations = itertools.product(net_names, ds_names, af_names) + epochs = 100 + nets = [] + + for n, ds, af in combinations: + nets_nds = [ + NetInfo(n, "base", ds, af, epochs, dspu4=False), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4 + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_fuzzy.py b/experiments/train_new_fuzzy.py new file mode 100644 index 0000000..a88a2a5 --- /dev/null +++ b/experiments/train_new_fuzzy.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + af_names = ["Tanh", "Sigmoid"] + combinations = itertools.product(net_names, ds_names, af_names) + epochs = 100 + nets = [] + + for n, ds, af in combinations: + nets_nds = [ + NetInfo(n, "fuzzy_ffn", ds, af, epochs), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4 + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/train_new_leaf.py b/experiments/train_new_leaf.py new file mode 100644 index 0000000..155a60c --- /dev/null +++ b/experiments/train_new_leaf.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +import itertools + +from train_common import train_variant +from misc import NetInfo + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + af_names = ["ReLU", "SiLU"] + combinations = itertools.product(net_names, ds_names, af_names) + epochs = 100 + nets = 
[] + + for n, ds, af in combinations: + nets_nds = [ + NetInfo(n, "leaf", ds, af, epochs, dspu4=False, p24sl=False), + NetInfo(n, "leaf", ds, af, epochs, dspu4=True, p24sl=False), + ] + nets.extend(nets_nds) + + for net in nets: + train_variant( + net.net_name, net.net_type, net.ds_name, af_name=net.af_name, + end_epoch=net.epoch, dspu4=net.dspu4, p24sl=net.p24sl + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/tune_patched_ahaf_only.py b/experiments/tune_patched_ahaf_only.py new file mode 100644 index 0000000..14fc3d0 --- /dev/null +++ b/experiments/tune_patched_ahaf_only.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +from train_common import train_variant +from tune_patched_common import freeze_non_af + + +def main(): + af_names = ("ReLU", "SiLU") + start_ep = 100 + end_ep = start_ep + 50 + + for af in af_names: + train_variant( + "KerasNet", "ahaf", "CIFAR-10", af_name=af, + start_epoch=start_ep, end_epoch=end_ep, patched=True, + param_freezer=freeze_non_af, save_as_fine_tuned=True + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/tune_patched_common.py b/experiments/tune_patched_common.py new file mode 100644 index 0000000..1fc2c3a --- /dev/null +++ b/experiments/tune_patched_common.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +from typing import Iterable, Union +import torch.nn + +from nns_aaf import KerasNetAaf, LeNetAaf + + +def param_freeze_status_to_symbol(param: torch.nn.Parameter) -> str: + if param.requires_grad: + return '🟢' + else: + return '🔴' + + +def print_params_freeze_status(params: Iterable[torch.nn.Parameter]): + status = [ + param_freeze_status_to_symbol(p) for p in params + ] + print(status) + + +def freeze_non_af(net: Union[KerasNetAaf, LeNetAaf]): + for p in net.parameters(): + p.requires_grad = False + + for afp in net.activation_params: + afp.requires_grad = True + + print_params_freeze_status(net.parameters()) diff --git a/misc/__init__.py b/misc/__init__.py new file mode 100644 index 0000000..1082cb8 --- /dev/null +++ b/misc/__init__.py @@ -0,0 +1,7 @@ +from .net_info import NetInfo +from .running_stat import RunningStat +from .progress_recorder import ProgressRecorder, ProgressElement +from .create_net import create_net, AF_FUZZY_DEFAULT_INTERVAL +from .file_names import get_file_name_net, get_file_name_opt,\ + get_file_name_stat, get_file_name_stat_img, get_file_name_stat_table,\ + get_file_name_aaf_img, get_file_name_checkp, get_file_name_train_args diff --git a/misc/create_net.py b/misc/create_net.py new file mode 100644 index 0000000..dfcd483 --- /dev/null +++ b/misc/create_net.py @@ -0,0 +1,67 @@ +from typing import Optional, Union + +from nns_aaf import LeNetAaf, KerasNetAaf, AfDefinition + + +AF_FUZZY_DEFAULT_INTERVAL = AfDefinition.AfInterval( + start=-4.0, end=+4.0, n_segments=16 +) + + +def create_net( + net_name: str, net_type: str, ds_name: str, af_name: str, + *, af_name_cnn: Optional[str] = None +) -> Union[LeNetAaf, KerasNetAaf]: + + if af_name_cnn is None: + af_name_cnn = af_name + + af_name_ffn = af_name + + if net_type == "base": + af_type_cnn = AfDefinition.AfType.TRAD + af_type_ffn = AfDefinition.AfType.TRAD + af_interval_ffn = None + elif net_type == "ahaf": + af_type_cnn = AfDefinition.AfType.ADA_AHAF + af_type_ffn = AfDefinition.AfType.ADA_AHAF + af_interval_ffn = None + elif net_type == "ahaf_ffn": + af_type_cnn = AfDefinition.AfType.TRAD + af_type_ffn = AfDefinition.AfType.ADA_AHAF + af_interval_ffn = None + elif net_type == "leaf": + af_type_cnn = AfDefinition.AfType.ADA_LEAF + af_type_ffn = 
AfDefinition.AfType.ADA_LEAF + af_interval_ffn = None + elif net_type == "leaf_ffn": + af_type_cnn = AfDefinition.AfType.TRAD + af_type_ffn = AfDefinition.AfType.ADA_LEAF + af_interval_ffn = None + elif net_type == "fuzzy_ffn": + af_type_cnn = AfDefinition.AfType.TRAD + af_type_ffn = AfDefinition.AfType.ADA_FUZZ + af_interval_ffn = AF_FUZZY_DEFAULT_INTERVAL + else: + raise ValueError("Network type is not supported") + + cnn_af = AfDefinition( + af_base=af_name_cnn, af_type=af_type_cnn + ) + + ffn_af = AfDefinition( + af_base=af_name_ffn, af_type=af_type_ffn, + af_interval=af_interval_ffn + ) + + if ds_name == "CIFAR-10": + ds_name = "CIFAR10" + + if net_name == "KerasNet": + net = KerasNetAaf(flavor=ds_name, af_conv=cnn_af, af_fc=ffn_af) + elif net_name == "LeNet-5": + net = LeNetAaf(flavor=ds_name, af_conv=cnn_af, af_fc=ffn_af) + else: + raise NotImplementedError("Only LeNet-5 and KerasNet are supported") + + return net diff --git a/misc/file_names.py b/misc/file_names.py new file mode 100644 index 0000000..dee6e2b --- /dev/null +++ b/misc/file_names.py @@ -0,0 +1,181 @@ +from typing import Optional + + +def _normalize_net_name(name: str) -> str: + if name == "KerasNet": + name = "kerasnet" + elif name == "LeNet-5": + name = "lenet5" + + return name + + +def _normalize_ds_name(name: str) -> str: + if name == "CIFAR-10": + name = "cifar10" + elif name == "F-MNIST": + name = "f-mnist" + + return name + + +def _format_net_variant(net_name: str, ds_name: str) -> str: + if ds_name: + net_name = f"{net_name}_{ds_name}" + + return net_name + + +def _format_net_af(af_name: str, af_name_cnn: Optional[str]) -> str: + if af_name_cnn and af_name != af_name_cnn: + af_name = "{}_{}".format(af_name, af_name_cnn) + + return af_name + + +def _get_file_name_net_base( + file_type: str, + net_name: str, net_type: str, ds_name: str, af_name: str, + epoch: int, patched: bool, fine_tuned: bool, + af_name_cnn: Optional[str], + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + patched_str = "patched_" if patched else "" + fine_tuned_str = "tuned_" if fine_tuned else "" + + net_name = _normalize_net_name(net_name) + ds_name = _normalize_ds_name(ds_name) + net_variant = _format_net_variant(net_name, ds_name) + opt_part = f"{opt_name}_" if opt_name else "" + train_alg = "dspu4_" if dspu4 else "" + train_lr = "p24sl_" if p24sl else "" + + if file_type == "stat": + extension = "csv" + elif file_type == "aaf_img": + extension = "svg" + elif file_type == "args": + extension = "json" + else: + extension = "bin" + + af_name = _format_net_af(af_name, af_name_cnn) + + file_path = f"runs/{net_variant}_{af_name}_" \ + f"{patched_str}{fine_tuned_str}{net_type}_" \ + f"{opt_part}" \ + f"{train_alg}{train_lr}{epoch}ep_{file_type}.{extension}" + + return file_path + + +def get_file_name_net( + net_name: str, net_type: str, ds_name: str, af_name: str, epoch: int, + patched: bool = False, fine_tuned: bool = False, + af_name_cnn: Optional[str] = None, + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + return _get_file_name_net_base( + "net", net_name, net_type, ds_name, af_name, epoch, + patched, fine_tuned, af_name_cnn, dspu4, p24sl, opt_name + ) + + +def get_file_name_opt( + net_name: str, net_type: str, ds_name: str, af_name: str, epoch: int, + patched: bool = False, fine_tuned: bool = False, + af_name_cnn: Optional[str] = None, + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + return _get_file_name_net_base( + "opt", net_name, 
net_type, ds_name, af_name, epoch, + patched, fine_tuned, af_name_cnn, dspu4, p24sl, opt_name + ) + + +def get_file_name_stat( + net_name: str, net_type: str, ds_name: str, af_name: str, epoch: int, + patched: bool = False, fine_tuned: bool = False, + af_name_cnn: Optional[str] = None, + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + return _get_file_name_net_base( + "stat", net_name, net_type, ds_name, af_name, epoch, + patched, fine_tuned, af_name_cnn, dspu4, p24sl, opt_name + ) + + +def _get_file_name_net_summary( + file_type: str, + net_name: str, ds_name, net_group: str, epoch: int, + patched: bool = False, fine_tuned: bool = False +) -> str: + if file_type == "stat_img": + extension = "svg" + else: + extension = "csv" + + patched_str = "_patched" if patched else "" + fine_tuned_str = "_tuned" if fine_tuned else "" + net_name = _normalize_net_name(net_name) + ds_name = _normalize_ds_name(ds_name) + net_variant = _format_net_variant(net_name, ds_name) + + file_path = f"runs/{net_variant}{patched_str}{fine_tuned_str}_" \ + f"summary_{net_group}_{epoch}ep_{file_type}.{extension}" + + return file_path + + +def get_file_name_stat_img( + net_name: str, ds_name, net_group: str, epoch: int, + patched: bool = False, fine_tuned: bool = False +) -> str: + return _get_file_name_net_summary( + "stat_img", net_name, ds_name, net_group, epoch, patched, fine_tuned + ) + + +def get_file_name_stat_table( + net_name: str, ds_name, net_group: str, epoch: int, + patched: bool = False, fine_tuned: bool = False +) -> str: + return _get_file_name_net_summary( + "stat_table", net_name, ds_name, net_group, epoch, patched, fine_tuned + ) + + +def get_file_name_aaf_img( + net_name: str, net_type: str, ds_name: str, af_name: str, epoch: int, + patched: bool = False, fine_tuned: bool = False, + af_name_cnn: Optional[str] = None, + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + return _get_file_name_net_base( + "aaf_img", net_name, net_type, ds_name, af_name, epoch, + patched, fine_tuned, af_name_cnn, dspu4, p24sl, opt_name + ) + + +def get_file_name_checkp( + net_name: str, net_type: str, ds_name: str, af_name: str, epoch: int, + patched: bool = False, fine_tuned: bool = False, + af_name_cnn: Optional[str] = None, + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + return _get_file_name_net_base( + "checkp", net_name, net_type, ds_name, af_name, epoch, + patched, fine_tuned, af_name_cnn, dspu4, p24sl, opt_name + ) + + +def get_file_name_train_args( + net_name: str, net_type: str, ds_name: str, af_name: str, epoch: int, + patched: bool = False, fine_tuned: bool = False, + af_name_cnn: Optional[str] = None, + dspu4: bool = False, p24sl: bool = False, opt_name: str = 'rmsprop' +) -> str: + return _get_file_name_net_base( + "args", net_name, net_type, ds_name, af_name, epoch, + patched, fine_tuned, af_name_cnn, dspu4, p24sl, opt_name + ) diff --git a/misc/net_info.py b/misc/net_info.py new file mode 100644 index 0000000..45aceba --- /dev/null +++ b/misc/net_info.py @@ -0,0 +1,15 @@ +from typing import NamedTuple, Optional + + +class NetInfo(NamedTuple): + net_name: str + net_type: str + ds_name: str + af_name: str + epoch: int + patched: bool = False, + fine_tuned: bool = False + af_name_cnn: Optional[str] = None + dspu4: bool = False + p24sl: bool = False + opt_name: str = "rmsprop" diff --git a/misc/progress_recorder.py b/misc/progress_recorder.py new file mode 100644 index 0000000..8bc0349 --- /dev/null +++ 
b/misc/progress_recorder.py @@ -0,0 +1,56 @@ +import time + +from typing import TypedDict, List, Mapping +from csv import DictWriter + + +class ProgressElement(TypedDict): + epoch: int + train_loss_mean: float + train_loss_var: float + test_acc: float + lr: str + duration: float + + +class ProgressRecorder(object): + def __init__(self): + self._els = [] # type: List[Mapping] + self._ep_start = None + self._ep_delta = None + + def start_ep(self): + self._ep_start = time.time() + + def end_ep(self): + ep_end = time.time() + + if self._ep_start is None: + self._ep_delta = None + else: + self._ep_delta = ep_end - self._ep_start + + self._ep_start = None + + def push_ep( + self, epoch: int, train_loss_mean: float, train_loss_var: float, + test_acc: float, lr: str + ): + self._els.append( + ProgressElement( + epoch=epoch, + train_loss_mean=train_loss_mean, + train_loss_var=train_loss_var, + test_acc=test_acc, lr=lr, + duration=self._ep_delta + ) + ) + self._ep_delta = None + + def save_as_csv(self, path: str): + fields = ProgressElement.__annotations__.keys() + + with open(path, 'w') as f: + writer = DictWriter(f, fields) + writer.writeheader() + writer.writerows(self._els) diff --git a/misc/running_stat.py b/misc/running_stat.py new file mode 100644 index 0000000..dd4e38a --- /dev/null +++ b/misc/running_stat.py @@ -0,0 +1,43 @@ +import math + + +# Based on https://www.johndcook.com/blog/standard_deviation/ +class RunningStat(object): + def __init__(self): + self._num_points = 0 # type: int + self._old_mean = None + self._old_variance = None + self._new_mean = None + self._new_variance = None + + def clear(self): + self._num_points = 0 + + def push(self, x: float): + self._num_points += 1 + + if self._num_points == 1: + self._old_mean = self._new_mean = x + self._old_variance = 0.0 + else: + self._new_mean = self._old_mean + (x - self._old_mean) / self._num_points + self._new_variance = self._old_variance + (x - self._old_mean) * (x - self._new_mean) + + self._old_mean = self._new_mean + self._old_variance = self._new_variance + + @property + def num_datapoints(self) -> int: + return self._num_points + + @property + def mean(self) -> float: + return self._new_mean if self._num_points > 0 else 0 + + @property + def variance(self) -> float: + return self._new_variance / (self._num_points - 1) if self._num_points > 1 else 0.0 + + @property + def stddev(self) -> float: + return math.sqrt(self.variance) diff --git a/nns_aaf/__init__.py b/nns_aaf/__init__.py new file mode 100644 index 0000000..54bda31 --- /dev/null +++ b/nns_aaf/__init__.py @@ -0,0 +1,2 @@ +from .kerasnet_aaf import KerasNetAaf, AfDefinition +from .lenet_aaf import LeNetAaf diff --git a/nns_aaf/cnn_aaf_base.py b/nns_aaf/cnn_aaf_base.py new file mode 100644 index 0000000..f128628 --- /dev/null +++ b/nns_aaf/cnn_aaf_base.py @@ -0,0 +1,74 @@ +from typing import Optional, List + +import torch.nn + +from adaptive_afs import AfDefinition + + +class CnnAafBase(torch.nn.Module): + def __init__( + self, *, flavor='MNIST', + af_conv: Optional[AfDefinition] = None, + af_fc: Optional[AfDefinition] = None + ): + super().__init__() + + if flavor == 'MNIST' or flavor == 'F-MNIST': + self._init_mnist_dims() + elif flavor == 'CIFAR10': + self._init_cifar_dims() + else: + raise NotImplemented("Other flavors of LeNet-5 are not supported") + + if af_conv is None: + # Use ReLU in the convolutional layers by default + af_conv = AfDefinition( + af_base="ReLU", af_type=AfDefinition.AfType.TRAD, + af_interval=None + ) + + self._af_def_conv = af_conv + + if 
af_fc is None: + # Use ReLU in the fully connected layers by default + af_fc = AfDefinition( + af_base="ReLU", af_type=AfDefinition.AfType.TRAD, + af_interval=None + ) + + self._af_def_fc = af_fc + + self._sequence = [] + + def _init_mnist_dims(self): + raise NotImplementedError( + "The MNIST support shall be implemented on the child class level" + ) + + def _init_cifar_dims(self): + raise NotImplementedError( + "The CIFAR-10 support shall be implemented on the child class level" + ) + + def forward(self, x): + for mod in self._sequence: + x = mod(x) + + return x + + @property + def activations(self): + raise NotImplementedError( + "The list of activations shall be specified on the child class " + "level" + ) + + @property + def activation_params(self) -> List[torch.nn.Parameter]: + params = [] + + for act in self.activations: + if isinstance(act, torch.nn.Module): + params.extend(act.parameters()) + + return params diff --git a/nns_aaf/kerasnet_aaf.py b/nns_aaf/kerasnet_aaf.py new file mode 100644 index 0000000..3215d0b --- /dev/null +++ b/nns_aaf/kerasnet_aaf.py @@ -0,0 +1,146 @@ +from typing import Optional, List + +import torch.nn + +from adaptive_afs import AfDefinition, af_build +from .cnn_aaf_base import CnnAafBase + + +class KerasNetAaf(CnnAafBase): + """ + KerasNet - CNN implementation evaluated in arXiv 1801.09403 **but** with + optional support of adaptive activation functions (AHAF, LEAF, F-Neuron + Activation). The model is based on the example CNN implementation from + Keras 1.x: git.io/JuHV0. + + Architecture: + + - 2D convolution 32 x (3,3) with (1,1) padding + - Conv Activation Function + - 2D convolution 32 x (3,3) w/o padding + - Conv Activation Function + - max pooling (2,2) + - dropout 25% + - 2D convolution 64 x (3,3) with (1,1) padding + - Conv Activation Function + - 2D convolution 64 x (3,3) w/o padding + - FFN Activation Function + - max pooling (2,2) + - dropout 25% + - fully connected, out_features = 512 + - FFN Activation Function + - dropout 50% + - fully connected, out_features = 10 + - softmax activation + + """ + + def __init__( + self, *, flavor='MNIST', + af_conv: Optional[AfDefinition] = None, + af_fc: Optional[AfDefinition] = None + ): + super().__init__(flavor=flavor, af_conv=af_conv, af_fc=af_fc) + + self.conv1 = torch.nn.Conv2d( + in_channels=self._image_channels, out_channels=32, + kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True + ) + self.act1 = af_build( + self._af_def_conv, + in_dims=(self.conv1.out_channels, *self._act1_img_dims) + ) + + self.conv2 = torch.nn.Conv2d( + in_channels=self.conv1.out_channels, out_channels=32, + kernel_size=(3, 3), stride=(1, 1), padding=(0, 0), bias=True + ) + self.act2 = af_build( + self._af_def_conv, + in_dims=(self.conv2.out_channels, *self._act2_img_dims) + ) + + self.pool3 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) + self.drop3 = torch.nn.Dropout2d(p=0.25) + + self.conv4 = torch.nn.Conv2d( + in_channels=self.conv2.out_channels, out_channels=64, + kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True + ) + self.act4 = af_build( + self._af_def_conv, + in_dims=(self.conv4.out_channels, *self._act4_img_dims) + ) + + self.conv5 = torch.nn.Conv2d( + in_channels=self.conv4.out_channels, out_channels=64, + kernel_size=(3, 3), stride=(1, 1), padding=(0, 0), bias=True + ) + self.act5 = af_build( + self._af_def_conv, + in_dims=(self.conv5.out_channels, *self._act5_img_dims), + ) + + self.pool6 = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)) + self.drop6 = torch.nn.Dropout2d(p=0.25) 
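+        # The flattened feature count consumed by fc7 follows from the
+        # conv/pool chain above. CIFAR-10: 32x32 -> conv1 (pad 1) 32x32 ->
+        # conv2 30x30 -> pool 15x15 -> conv4 (pad 1) 15x15 -> conv5 13x13 ->
+        # pool 6x6, i.e. _fc7_in_features = 6 * 6 * 64. (F-)MNIST:
+        # 28x28 -> 28x28 -> 26x26 -> 13x13 -> 13x13 -> 11x11 -> 5x5,
+        # i.e. _fc7_in_features = 5 * 5 * 64.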
+ + self._flatter = torch.nn.Flatten(start_dim=1, end_dim=-1) + + self.fc7 = torch.nn.Linear( + in_features=self._fc7_in_features, + out_features=self._fc8_out_features, bias=True + ) + + self.act7 = af_build(self._af_def_fc, in_dims=(self._fc8_out_features,)) + + self.drop7 = torch.nn.Dropout(p=0.2) + + self.fc8 = torch.nn.Linear( + in_features=self._fc8_out_features, out_features=10, bias=True + ) + + # softmax is embedded in pytorch's loss function + + self._sequence = [ + self.conv1, self.act1, self.conv2, self.act2, + self.pool3, self.drop3, + self.conv4, self.act4, self.conv5, self.act5, + self.pool6, self.drop6, + self._flatter, + self.fc7, self.act7, self.drop7, + self.fc8 + ] + + def _init_mnist_dims(self): + self._image_channels = 1 + self._fc7_in_features = 5 * 5 * 64 + self._fc8_out_features = 512 + self._act1_img_dims = (28, 28) + self._act2_img_dims = (26, 26) + self._act4_img_dims = (13, 13) + self._act5_img_dims = (11, 11) + + def _init_cifar_dims(self): + self._image_channels = 3 + self._fc7_in_features = 6 * 6 * 64 + self._fc8_out_features = 512 + self._act1_img_dims = (32, 32) + self._act2_img_dims = (30, 30) + self._act4_img_dims = (15, 15) + self._act5_img_dims = (13, 13) + + @property + def activations(self): + return [ + self.act1, self.act2, self.act4, self.act5, self.act7 + ] + + @property + def activation_params(self) -> List[torch.nn.Parameter]: + params = [] + + for act in self.activations: + if isinstance(act, torch.nn.Module): + params.extend(act.parameters()) + + return params diff --git a/nns_aaf/lenet_aaf.py b/nns_aaf/lenet_aaf.py new file mode 100644 index 0000000..e353655 --- /dev/null +++ b/nns_aaf/lenet_aaf.py @@ -0,0 +1,79 @@ +from typing import Optional, List + +import torch.nn + +from adaptive_afs import AfDefinition, af_build +from .cnn_aaf_base import CnnAafBase + + +class LeNetAaf(CnnAafBase): + """ + Implementation of LeNet-5 with the support of AHAF, LEAF and the + F-Neuron Activation. The network structure follows arXiv 1801.09403. + """ + + def __init__( + self, *, flavor='MNIST', + af_conv: Optional[AfDefinition] = None, + af_fc: Optional[AfDefinition] = None + ): + super().__init__(flavor=flavor, af_conv=af_conv, af_fc=af_fc) + + self.conv1 = torch.nn.Conv2d( + in_channels=self._image_channels, out_channels=20, + kernel_size=(5, 5), stride=(1, 1), padding=(0, 0), bias=False + ) + self.act1 = af_build( + self._af_def_conv, + in_dims=(self.conv1.out_channels, *self._act1_img_dims) + ) + self.pool1 = torch.nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = torch.nn.Conv2d( + in_channels=20, out_channels=50, kernel_size=(5, 5), + stride=(1, 1), padding=(0, 0), bias=False + ) + self.act2 = af_build( + self._af_def_conv, + in_dims=(self.conv2.out_channels, *self._act2_img_dims) + ) + self.pool2 = torch.nn.MaxPool2d(kernel_size=(2, 2)) + + self._flatter = torch.nn.Flatten(start_dim=1, end_dim=-1) + + self.fc3 = torch.nn.Linear( + in_features=self._fc3_in_features, out_features=500, bias=True + ) + self.act3 = af_build(self._af_def_fc, in_dims=(self.fc3.out_features,)) + + self.fc4 = torch.nn.Linear( + in_features=500, out_features=10, bias=False + ) + # SoftMax is embedded into the Cross Entropy loss function. 
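+        # The flattened feature count consumed by fc3 follows from the
+        # conv/pool chain: (F-)MNIST 28x28 -> conv1 24x24 -> pool 12x12 ->
+        # conv2 8x8 -> pool 4x4, i.e. _fc3_in_features = 4 * 4 * 50;
+        # CIFAR-10 32x32 -> 28x28 -> 14x14 -> 10x10 -> 5x5, i.e. 5 * 5 * 50.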
+ + self._sequence = [ + self.conv1, self.act1, self.pool1, + self.conv2, self.act2, self.pool2, + self._flatter, + self.fc3, self.act3, + self.fc4 + ] + + def _init_mnist_dims(self): + self._image_channels = 1 + self._fc3_in_features = 4 * 4 * 50 + self._act1_img_dims = (24, 24) + self._act2_img_dims = (8, 8) + + def _init_cifar_dims(self): + self._image_channels = 3 + self._fc3_in_features = 5 * 5 * 50 + self._act1_img_dims = (28, 28) + self._act2_img_dims = (10, 10) + + @property + def activations(self): + return [ + self.act1, self.act2, self.act3 + ] + diff --git a/post_experiment/__init__.py b/post_experiment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/post_experiment/show_aaf_form.py b/post_experiment/show_aaf_form.py new file mode 100644 index 0000000..f724aa7 --- /dev/null +++ b/post_experiment/show_aaf_form.py @@ -0,0 +1,313 @@ +import itertools +import warnings +from typing import Sequence, Tuple, Optional + +import torch +import matplotlib.pyplot as plt +from cycler import cycler + +from adaptive_afs import FNeuronAct +from adaptive_afs import af_build, AfDefinition +from misc import get_file_name_checkp, get_file_name_aaf_img, create_net,\ + AF_FUZZY_DEFAULT_INTERVAL, NetInfo + + +def get_random_idxs(max_i, cnt=10) -> Sequence[int]: + return [ + int(torch.randint(size=(1,), low=0, high=max_i)) for _ in range(cnt) + ] + + +def random_selection(params, idxs): + return [params[i] for i in idxs] + + +def visualize_af_base(af_name: str, x, subfig): + base_def = AfDefinition( + af_type=AfDefinition.AfType.TRAD, + af_base=af_name + ) + af = af_build(base_def) + y = af(x) + x_view = x.cpu().numpy() + y_view = y.cpu().numpy() + subfig.plot(x_view, y_view) + + +def visualize_af_ahaf(rho1, rho3, x, subfig): + y = (rho1 * x) * torch.sigmoid(rho3 * x) + x_view = x.cpu().numpy() + y_view = y.cpu().numpy() + subfig.plot(x_view, y_view) + + +def visualize_af_leaf(params, x, subfig): + rho1, rho2, rho3, rho4 = params + y = (rho1 * x + rho2) * torch.sigmoid(rho3 * x) + rho4 + x_view = x.cpu().numpy() + y_view = y.cpu().numpy() + subfig.plot(x_view, y_view) + + +def visualize_af_fuzzy( + fuzzy_def: AfDefinition, x: torch.Tensor, weights: torch.nn.Parameter, + subfig +): + def _restore_weights( + count: int, input_dim: Tuple[int, ...], + in_range: Tuple[float, float] = (-1.0, +1.0) + ) -> torch.Tensor: + return weights.data + + af = FNeuronAct( + left=fuzzy_def.interval.start, + right=fuzzy_def.interval.end, + count=weights.size(dim=-1) - 2, + init_f=_restore_weights + ) + y = af(x) + + x_view = x.cpu().numpy() + y_view = y.cpu().numpy() + + subfig.plot(x_view, y_view) + + +def visualize_afs_ahaf_by_params( + params: Sequence[torch.nn.Parameter], fig, rows, show_subtitles, + reference_af_name: Optional[str] = None +): + num_neurons = len(params) // 2 + start_index = max(0, num_neurons - rows) + cols = 5 + + x = torch.arange(start=-10, end=4.0, step=0.1, + device=params[0].device) + + gs = plt.GridSpec(rows, cols) + + for i in range(rows): + param_idx = start_index + i + all_gamma = params[param_idx * 2].view(-1) + all_beta = params[param_idx * 2 + 1].view(-1) + sel = get_random_idxs(max_i=len(all_gamma), cnt=cols) + sel_gamma = random_selection(all_gamma, sel) + sel_beta = random_selection(all_beta, sel) + + for j in range(cols): + subfig = fig.add_subplot(gs[i, j]) + if show_subtitles: + subfig.set_title("L{} F{}".format(i, sel[j])) + subfig.title.set_size(10) + + gamma = sel_gamma[j] + beta = sel_beta[j] + + if reference_af_name is not None: + 
visualize_af_base(reference_af_name, x, subfig=subfig) + + visualize_af_ahaf(beta, gamma, x, subfig=subfig) + + +def visualize_afs_leaf_by_params( + params: Sequence[torch.nn.Parameter], fig, rows, show_subtitles, + reference_af_name: Optional[str] = None +): + num_neurons = len(params) // 4 + start_index = max(0, num_neurons - rows) + cols = 5 + + x = torch.arange(start=-10, end=4.0, step=0.1, + device=params[0].device) + + gs = plt.GridSpec(rows, cols) + + for i in range(rows): + param_idx = start_index + i + all_p1 = params[param_idx * 4].view(-1) + all_p2 = params[param_idx * 4 + 1].view(-1) + all_p3 = params[param_idx * 4 + 2].view(-1) + all_p4 = params[param_idx * 4 + 3].view(-1) + + sel = get_random_idxs(max_i=len(all_p3), cnt=cols) + sel_p1 = random_selection(all_p1, sel) + sel_p2 = random_selection(all_p2, sel) + sel_p3 = random_selection(all_p3, sel) + sel_p4 = random_selection(all_p4, sel) + + for j in range(cols): + subfig = fig.add_subplot(gs[i, j]) + if show_subtitles: + subfig.set_title("L{} F{}".format(i, sel[j])) + subfig.title.set_size(10) + + instance = sel_p1[j], sel_p2[j], sel_p3[j], sel_p4[j] + + if reference_af_name is not None: + visualize_af_base(reference_af_name, x, subfig=subfig) + + visualize_af_leaf(instance, x, subfig=subfig) + + +def visualize_afs_fuzzy_by_params( + params: Sequence[torch.nn.Parameter], fig, rows, show_subtitles, + reference_af_name: Optional[str] = None +): + num_neurons = len(params) // 1 + start_index = max(0, num_neurons - rows) + cols = 5 + + # WARNING: Keep the definition updated with create_net + fuzzy_def = AfDefinition( + af_base="DoNotCare", # does not matter for pre-init functions + af_type=AfDefinition.AfType.ADA_FUZZ, + af_interval=AF_FUZZY_DEFAULT_INTERVAL + ) + + # Longer range to visualize Sigmoid and Tanh + x = torch.arange(start=-3.5, end=+3.5, step=0.1, device=params[0].device) + + gs = plt.GridSpec(rows, cols) + + for i in range(rows): + param_idx = start_index + i + all_mfs_weights = params[param_idx] + sel = get_random_idxs(max_i=len(all_mfs_weights), cnt=cols) + sel_mfs_weights = random_selection(all_mfs_weights, sel) + + for j in range(cols): + subfig = fig.add_subplot(gs[i, j]) + if show_subtitles: + subfig.set_title("L{} F{}".format(i, sel[j])) + subfig.title.set_size(10) + + weights = sel_mfs_weights[j] + + if reference_af_name is not None: + visualize_af_base(reference_af_name, x, subfig=subfig) + + visualize_af_fuzzy(fuzzy_def, x, weights, subfig=subfig) + + +def _is_same_af_name(net_info: NetInfo) -> bool: + if net_info.af_name_cnn is None: + return True + + if net_info.af_name_cnn == net_info.af_name: + return True + + if net_info.net_type == "fuzzy_ffn": + return True + + # leaves AHAF with different AFs in CNN and FFN + return False + + +def _get_reference_af_name(net_info: NetInfo) -> Optional[str]: + if _is_same_af_name(net_info): + return net_info.af_name + else: + warnings.warn("The network contains different activations in the CNN " + "and FFN layers. 
Unable to visualize the base function.") + return None + + +def visualize_afs(net_info: NetInfo, max_rows: int = 2, bw=False, + show_reference: bool = False): + torch.manual_seed(seed=128) + + if net_info.net_type == "ahaf" or net_info.net_type == "ahaf_ffn": + params_per_neuron = 2 # constant for AHAF + visualizer = visualize_afs_ahaf_by_params + elif net_info.net_type == "leaf" or net_info.net_type == "leaf_ffn": + params_per_neuron = 4 # constant for LEAF + visualizer = visualize_afs_leaf_by_params + elif net_info.net_type == "fuzzy_ffn": + params_per_neuron = 1 # constant for Fuzzy AF, 1 param set per neuron + visualizer = visualize_afs_fuzzy_by_params + else: + raise ValueError("Network type is not supported") + + img_path = get_file_name_aaf_img(*net_info) + checkp_path = get_file_name_checkp(*net_info) + checkp = torch.load(checkp_path) + net = create_net( + net_info.net_name, net_info.net_type, + net_info.ds_name, net_info.af_name, + af_name_cnn=net_info.af_name_cnn + ) + net.load_state_dict(checkp['net']) + + af_params = net.activation_params + num_neurons = len(af_params) // params_per_neuron + + show_subtitles = True + + if max_rows is None: + rows = num_neurons + else: + rows = min(max_rows, num_neurons) + + height = 1.0 * rows + + if bw: + # TODO: Set locally for this figure, not globally + monochrome = cycler('color', ['black', 'grey']) + plt.rcParams['axes.prop_cycle'] = monochrome + + tight_layout = {'pad': 0.35} + fig = plt.figure(tight_layout=tight_layout, figsize=(5, height)) + + if show_reference: + ref_af_name = _get_reference_af_name(net_info) + else: + ref_af_name = None + + with torch.no_grad(): + visualizer( + af_params, fig, rows, show_subtitles, + reference_af_name=ref_af_name + ) + + plt.savefig(img_path, dpi=300, format='svg') + plt.close(fig) + + +def main(): + max_rows = 5 + + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + net_ds_combinations = itertools.product(net_names, ds_names) + opt = "adam" + nets = [] + + for n, ds in net_ds_combinations: + nets_nds = [ + NetInfo(n, "ahaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=False, opt_name=opt), + + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=False, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=False, opt_name=opt), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + nets.extend(nets_nds) + + for net in nets: + try: + visualize_afs(net, max_rows, bw=True, show_reference=True) + except FileNotFoundError: + continue + except Exception as e: + print("Exception: {}, skipped".format(e)) + continue + + +if __name__ == "__main__": + main() diff --git a/post_experiment/show_af_diff.py b/post_experiment/show_af_diff.py new file mode 100644 
index 0000000..016ed4c --- /dev/null +++ b/post_experiment/show_af_diff.py @@ -0,0 +1,157 @@ +import contextlib +import itertools +from typing import Optional, NamedTuple + +import torch +import torch.nn.functional +from matplotlib import pyplot as plt +from cycler import cycler + +from adaptive_afs import af_build, AfDefinition +from adaptive_afs.trad import tanh_manual, silu_manual + + +class ErrorHolder(NamedTuple): + min_error: float + max_error: float + + +def estimate_error( + orig_fn, drv_fn, left=-4.0, right=+4.0, img_path: Optional[str] = None +) -> ErrorHolder: + n_points = 100000 + + range_len = right - left + step = range_len / n_points + eps = step / 100 + + with torch.no_grad(): + x = torch.arange(start=left, end=right + eps, step=step) + + y = orig_fn(x) + y_hat = drv_fn(x) + + errors = torch.square(y - y_hat) + max_error = torch.max(errors) + min_error = torch.min(errors) + + x_view = x.cpu().numpy() + err_view = errors.cpu().numpy() + + monochrome = cycler('color', ['black']) + plt.rcParams['axes.prop_cycle'] = monochrome + + plt.xlabel("Input, x") + plt.ylabel("Error, E=Δ^2") + plt.title(f"min={min_error.item()},max={max_error.item()}") + + plt.plot(x_view, err_view) + + if img_path is None: + plt.show() + else: + plt.savefig(img_path, dpi=300, format='svg') + plt.close() + + return ErrorHolder(min_error.item(), max_error.item()) + + +def estimate_err_manual_silu( + left=-4.0, right=+4.0, img_path: Optional[str] = None +) -> ErrorHolder: + orig_fn = torch.nn.functional.silu + drv_fn = silu_manual + + return estimate_error(orig_fn, drv_fn, left, right, img_path) + + +def estimate_err_manual_tanh( + left=-4.0, right=+4.0, img_path: Optional[str] = None +) -> ErrorHolder: + orig_fn = torch.tanh + drv_fn = tanh_manual + + return estimate_error(orig_fn, drv_fn, left, right, img_path) + + +def estimate_err_aaf( + af_def: AfDefinition, + left=-4.0, right=+4.0, img_path: Optional[str] = None +) -> ErrorHolder: + orig_fn = af_build( + AfDefinition(af_def.af_base, AfDefinition.AfType.TRAD) + ) + drv_fn = af_build(af_def) + + return estimate_error(orig_fn, drv_fn, left, right, img_path) + + +def estimate_all(dev_name: str, prec_name: str): + estimate_err_manual_silu( + -15.0, +15.0, + img_path=f"runs/af_diff_manual_silu_{dev_name}_{prec_name}.svg" + ) + estimate_err_manual_tanh( + -15.0, +15.0, + img_path=f"runs/af_diff_manual_tanh_{dev_name}_{prec_name}.svg" + ) + + af_defs_fuzz = [ + AfDefinition(af_base="Tanh", af_type=AfDefinition.AfType.ADA_FUZZ, + af_interval=AfDefinition.AfInterval(-12.0, +12.0, 768)) + ] + + for ff in af_defs_fuzz: + img_path = f"runs/af_diff_fuzz_{ff.af_base}_{dev_name}_{prec_name}.svg" + estimate_err_aaf(ff, ff.interval.start, ff.interval.end, img_path) + + af_names_ahaf = ["ReLU", "SiLU"] + + for afn in af_names_ahaf: + img_path = f"runs/af_diff_ahaf_{afn}_{dev_name}_{prec_name}.svg" + af_def = AfDefinition(af_base=afn, af_type=AfDefinition.AfType.ADA_AHAF) + estimate_err_aaf(af_def, -15.0, +15.0, img_path) + + af_names_leaf = ["ReLU", "SiLU", "Tanh", "Sigmoid"] + + for afn in af_names_leaf: + img_path = f"runs/af_diff_leaf_{afn}_{dev_name}_{prec_name}.svg" + af_def = AfDefinition(af_base=afn, af_type=AfDefinition.AfType.ADA_LEAF) + estimate_err_aaf(af_def, -15.0, +15.0, img_path) + + +@contextlib.contextmanager +def precision(name: str): + prev_dtype = torch.get_default_dtype() + + if name == "float16": + new_dtype = torch.float16 + elif name == "float64": + new_dtype = torch.float64 + else: + new_dtype = torch.float32 + + 
torch.set_default_dtype(new_dtype) + + try: + yield + finally: + torch.set_default_dtype(prev_dtype) + + +def main(): + devices = ["cpu", "cuda"] + precisions = ["float16", "float32", "float64"] + + for dev, prec in itertools.product(devices, precisions): + if dev == "cpu" and prec == "float16": + # Skip, not implemented in PyTorch + continue + + with torch.device(dev): + with precision(prec): + estimate_all(dev, prec) + + +if __name__ == "__main__": + main() diff --git a/post_experiment/show_progress_charts.py b/post_experiment/show_progress_charts.py new file mode 100644 index 0000000..0c6d27f --- /dev/null +++ b/post_experiment/show_progress_charts.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 + +import itertools +import csv +import matplotlib.pyplot as plt + +from typing import Sequence, Tuple, List, Union +from cycler import cycler + +from misc import ProgressElement +from misc import NetInfo +from misc import get_file_name_stat, get_file_name_stat_img + + +def load_results(file_path: str) -> List[ProgressElement]: + with open(file_path, 'r') as f: + reader = csv.DictReader(f) + return list(reader) + + +def get_legend_long(net_info: Union[Tuple, NetInfo]) -> str: + if not isinstance(net_info, NetInfo): + net_info = NetInfo(*net_info) + + af_name_ffn = net_info.af_name + + if net_info.af_name_cnn is None: + af_name_cnn = af_name_ffn + else: + af_name_cnn = net_info.af_name_cnn + + if net_info.net_type == "base": + legend = "{} CNN, {} FFN".format(af_name_cnn, af_name_ffn) + elif net_info.net_type == "ahaf": + legend = "{}-like AHAF CNN, {}-like AHAF FFN".format( + af_name_cnn, af_name_ffn + ) + elif net_info.net_type == "ahaf_ffn": + legend = "{} CNN, {}-like AHAF FFN".format( + af_name_cnn, af_name_ffn + ) + elif net_info.net_type == "leaf": + legend = "{}-like LEAF CNN, {}-like LEAF FFN".format( + af_name_cnn, af_name_ffn + ) + elif net_info.net_type == "leaf_ffn": + legend = "{} CNN, {}-like LEAF FFN".format( + af_name_cnn, af_name_ffn + ) + elif net_info.net_type == "fuzzy_ffn": + legend = "{} CNN, {}-like Fuzzy FFN".format( + af_name_cnn, af_name_ffn + ) + else: + raise ValueError("Network type is not supported") + + if net_info.fine_tuned: + legend = legend + ", fine-tuned" + + if net_info.dspu4: + legend = legend + ", 2SPU-4" + + if net_info.p24sl: + legend = legend + ", slow LEAF {p2,p4} update" + + return legend + + +def get_short_af_name(orig: str) -> str: + if orig == "Tanh": + return "tanh" + elif orig == "Sigmoid": + return "σ-fn" + else: + return orig + + +def get_legend_short( + net_info: Union[Tuple, NetInfo], omit_af_names: bool = False, + include_opt: bool = False +) -> str: + if not isinstance(net_info, NetInfo): + net_info = NetInfo(*net_info) + + af_name_ffn = net_info.af_name + + if net_info.af_name_cnn is None: + af_name_cnn = af_name_ffn + else: + af_name_cnn = net_info.af_name_cnn + + af_name_cnn = get_short_af_name(af_name_cnn) + af_name_ffn = get_short_af_name(af_name_ffn) + + if net_info.net_type == "base": + net_type_str = "Base" + elif net_info.net_type == "ahaf": + net_type_str = "AHAF" + elif net_info.net_type == "ahaf_ffn": + net_type_str = "AHAF FFN" + elif net_info.net_type == "leaf": + net_type_str = "LEAF" + elif net_info.net_type == "leaf_ffn": + net_type_str = "LEAF FFN" + elif net_info.net_type == "fuzzy_ffn": + net_type_str = "Fuzzy" + else: + raise ValueError("Network type is not supported") + + if omit_af_names: + legend = net_type_str + else: + legend = f"{net_type_str}, {af_name_cnn}, {af_name_ffn}" + + if net_info.opt_name == 'adam': + 
opt_name_str = 'ADAM' + elif net_info.opt_name == 'rmsprop': + opt_name_str = 'RMSprop' + else: + raise ValueError("Optimizer is not supported") + + if include_opt: + legend = f"{legend}, {opt_name_str}" + + if net_info.fine_tuned: + legend = legend + ", tuned" + + if net_info.dspu4: + legend = legend + ", DSPT" + + if net_info.p24sl: + legend = legend + ", P24Sl" + + return legend + + +def analyze_network( + net_info: Tuple, omit_af_names: bool = False, include_opt: bool = False +): + file_path = get_file_name_stat(*net_info) + results = load_results(file_path) + base_legend = get_legend_short(net_info, omit_af_names, include_opt) + + acc = [] + loss = [] + + for r in results: + acc.append(float(r["test_acc"]) * 100.0) + loss.append(float(r["train_loss_mean"])) + + return base_legend, acc, loss + + +def plot_networks( + fig, nets: Sequence[Union[Tuple, NetInfo]], + bw=False, omit_af_names=False, include_opt=False +) -> bool: + acc_legends = [] + loss_legends = [] + + monochrome = ( + cycler('linestyle', ['-', '--', ':', '-.']) + * cycler('color', ['black', 'grey']) + * cycler('marker', ['None']) + ) + + gs = plt.GridSpec(1, 2) + + acc_fig = fig.add_subplot(gs[0, 0]) + #acc_loc = plticker.LinearLocator(numticks=10) + #acc_fig.yaxis.set_major_locator(acc_loc) + acc_fig.set_xlabel('epoch') + acc_fig.set_ylabel('test accuracy, %') + acc_fig.grid() + if bw: + acc_fig.set_prop_cycle(monochrome) + + loss_fig = fig.add_subplot(gs[0, 1]) + #loss_loc = plticker.LinearLocator(numticks=10) + #loss_fig.yaxis.set_major_locator(loss_loc) + loss_fig.set_xlabel('epoch') + loss_fig.set_ylabel('training loss') + loss_fig.grid() + if bw: + loss_fig.set_prop_cycle(monochrome) + + net_processed = 0 + + for net in nets: + try: + base_legend, acc, loss = analyze_network( + net, omit_af_names, include_opt + ) + except FileNotFoundError: + continue + except Exception as e: + print("Exception: {}, skipped".format(e)) + continue + + net_processed += 1 + n_epochs = len(acc) + end_ep = net.epoch + start_ep = end_ep - n_epochs + + x = tuple(range(start_ep, end_ep)) + + acc_legends.append( + base_legend + ) + loss_legends.append( + base_legend + ) + acc_fig.plot(x, acc) + loss_fig.plot(x, loss) + + acc_fig.legend(acc_legends) + loss_fig.legend(loss_legends) + + return net_processed > 0 + + +def visualize( + net_name: str, ds_name: str, net_group: str, + nets: Sequence[Union[Tuple, NetInfo]], base_title=None, + bw: bool = False, omit_af_names: bool = False, + include_opt: bool = False +): + fig = plt.figure(tight_layout=True, figsize=(6, 3)) + if base_title is not None: + title = "{}, test accuracy and training loss".format(base_title) + fig.suptitle(title) + + success = plot_networks(fig, nets, bw, omit_af_names, include_opt) + if success: + #plt.show() + plt.savefig(get_file_name_stat_img(net_name, ds_name, net_group, + nets[0].epoch, nets[0].patched, + nets[0].fine_tuned)) + + plt.close(fig) + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + net_ds_combinations = itertools.product(net_names, ds_names) + opt = "adam" + + omit_all_captions = True + + for n, ds in net_ds_combinations: + relu_comparison = [ + NetInfo(n, "base", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=True, opt_name=opt), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name=opt), + 
NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + + visualize( + n, ds, f"relu_{opt}", relu_comparison, + None if omit_all_captions else f"{n} on {ds} - ReLU-like AFs", + bw=False, omit_af_names=True + ) + + leaf_relu_unstable_cmp = [ + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name='adam'), + NetInfo(n, "ahaf", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name='rmsprop'), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name='rmsprop'), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name='rmsprop'), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name='adam'), + ] + + visualize( + n, ds, f"leaf_relu_stability", leaf_relu_unstable_cmp, + None if omit_all_captions else "Learning rate effect on ReLU-like LEAFs", + bw=True, omit_af_names=True, include_opt=True + ) + + silu_comparison = [ + NetInfo(n, "base", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + + visualize( + n, ds, f"silu_{opt}", silu_comparison, + None if omit_all_captions else f"{n} on {ds} - SiLU-like AFs", + bw=False, omit_af_names=True + ) + + leaf_silu_unstable_cmp = [ + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name='adam'), + NetInfo(n, "ahaf", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name='rmsprop'), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name='rmsprop'), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name='rmsprop'), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name='adam'), + ] + + visualize( + n, ds, f"leaf_silu_stability", leaf_silu_unstable_cmp, + None if omit_all_captions else "Learning rate effect on SiLU-like LEAFs", + bw=True, omit_af_names=True, include_opt=True + ) + + tanh_comparison = [ + NetInfo(n, "base", ds, "Tanh", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=False, opt_name=opt), + #NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=False, opt_name=opt), + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + + visualize( + n, ds, f"tanh_{opt}", tanh_comparison, + None if omit_all_captions else f"{n} on {ds} - Tanh-like AFs", + bw=True, omit_af_names=True + ) + + sigmoid_comparison = [ + NetInfo(n, "base", ds, "Sigmoid", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=False, opt_name=opt), + #NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=False, opt_name=opt), + NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, 
"Sigmoid", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + + visualize( + n, ds, f"sigmoid_{opt}", sigmoid_comparison, + None if omit_all_captions else f"{n} on {ds} - Sigmoid-like AFs", + bw=True, omit_af_names=True + ) + + +if __name__ == "__main__": + main() diff --git a/post_experiment/show_progress_summary.py b/post_experiment/show_progress_summary.py new file mode 100644 index 0000000..3a521e5 --- /dev/null +++ b/post_experiment/show_progress_summary.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python3 + +import itertools +import csv + +from typing import Sequence, Tuple, List, NamedTuple, Union, Generator, Dict +from decimal import Decimal + +from misc import ProgressElement +from misc import NetInfo +from misc import get_file_name_stat, get_file_name_stat_table + + +def load_results(file_path: str) -> List[ProgressElement]: + with open(file_path, 'r') as f: + reader = csv.DictReader(f) + return list(reader) + + +def analyze_network(net_info: NetInfo) -> Tuple[int, float, float]: + """ + TBD + + :param net_info: TBD + :return: best_epoch, best_accuracy, avg_it_sec + """ + file_path = get_file_name_stat(*net_info) + results = load_results(file_path) + + max_acc = -1.0 + max_pos = -1 + pos = 0 + total_duration = 0 + # best_duration = -1.0 + pos_offset = net_info.epoch - len(results) + + for el in results: + pos += 1 + acc = Decimal(el["test_acc"]) * 100 + total_duration += Decimal(el["duration"]) + + if acc > max_acc: + max_acc = acc + max_pos = pos + pos_offset + # best_duration = total_duration + + avg_it_time = total_duration / len(results) + return max_pos, max_acc, round(avg_it_time, 2) + + +class SummaryItem(NamedTuple): + net_type: str + af_cnn: str + af_ffn: str + best_acc: float + best_ep: int + avg_it_sec: float + tuned: bool + dspu4: bool + p24sl: bool + + +class SuperSummKey(NamedTuple): + net: str + net_type: str + act: str + tuned: bool + dspu4: bool + p24sl: bool + + +class SuperSummSubItem(NamedTuple): + best_acc: float + best_ep: int + + +class SuperSummItem(NamedTuple): + fmnist: SuperSummSubItem + cifar10: SuperSummSubItem + + +SuperSumm = Dict[SuperSummKey, SuperSummItem] + + +def gather_results(nets: Sequence[Union[Tuple, NetInfo]]) -> List[SummaryItem]: + results = [] + + for net in nets: + if not isinstance(net, NetInfo): + net = NetInfo(*net) + try: + best_ep, best_acc, duration = analyze_network(net) + except FileNotFoundError: + continue + except Exception as e: + print("Exception: {}, skipped".format(e)) + continue + + net_af_cnn = net.af_name_cnn if net.af_name_cnn else net.af_name + net_af_ffn = net.af_name + + results.append( + SummaryItem(net.net_type, net_af_cnn, net_af_ffn, best_acc, best_ep, + duration, net.fine_tuned, net.dspu4, net.p24sl) + ) + + return results + + +def prettify_net_type_short(net_type: str, fine_tuned: bool = False) -> str: + if net_type == "base": + net_type = "Base" + elif net_type == "ahaf": + net_type = "AHAF" + elif net_type == "ahaf_ffn": + net_type = "AHAF FFN" + elif net_type == "leaf": + net_type = "LEAF" + elif net_type == "leaf_ffn": + net_type = "LEAF FFN" + elif net_type == "fuzzy_ffn": + net_type = "Fuzzy" + else: + raise ValueError("Network type is not supported") + + if fine_tuned: + net_type = net_type + " tuned" + + return net_type + + +def prettify_net_type_long(net_name: str, net_type: str, fine_tuned: bool = False) -> str: + if net_type == "base": + net_type = f"Base {net_name}" + elif net_type == "ahaf": + net_type = f"{net_name} w/ AHAF" + elif net_type == "ahaf_ffn": + net_type = 
f"{net_name} w/ AHAF FFN" + elif net_type == "leaf": + net_type = f"{net_name} w/ LEAF" + elif net_type == "leaf_ffn": + net_type = f"{net_name} w/ LEAF FFN" + elif net_type == "fuzzy_ffn": + net_type = f"{net_name} w/ Fuzzy FFN" + else: + raise ValueError("Network type is not supported") + + if fine_tuned: + net_type = net_type + " fine-tuned" + + return net_type + + +def prettify_result(el: SummaryItem) -> Tuple: + net_type = prettify_net_type_short(el.net_type, el.tuned) + + if el.dspu4 and el.p24sl: + train_str = 'DSPT, P24Sl' + elif el.dspu4: + train_str = 'DSPT' + elif el.p24sl: + train_str = 'P24Sl' + else: + train_str = 'Stand.' + + return ( + net_type, el.af_cnn, el.af_ffn, train_str, + el.best_acc, el.best_ep, el.avg_it_sec + ) + + +def prettify_results( + results: Sequence[SummaryItem] +) -> Generator[Tuple, None, None]: + for el in results: + yield prettify_result(el) + + +def save_results_as_csv(results: List[SummaryItem], path: str): + with open(path, 'w') as f: + writer = csv.writer(f) + writer.writerow( + ("Type", "CNN AF", "FFN AF", "Tr. Alg.", "Accuracy, %", "Epoch", "It. time, s") + ) + writer.writerows(prettify_results(results)) + + +def summarize( + net_name: str, ds_name: str, net_group: str, + nets: Sequence[Union[Tuple, NetInfo]] +) -> List[SummaryItem]: + results = gather_results(nets) + + if results: + save_results_as_csv(results, get_file_name_stat_table( + net_name, ds_name, net_group, + nets[0].epoch, nets[0].patched, nets[0].fine_tuned + )) + + return results + + +def results_to_tex(results: SuperSumm, file_name: str, alt_style: bool = False): + header1 = """ +\\begin{table}[htbp] +\t\\caption{Best test set accuracy, up to 100 epochs} +\t\\label{table:tab1} +\t\\begin{tabular}{llllcccc} +\t\t\\toprule +\t\t& & & & \\multicolumn{2}{c}{Fashion-MNIST} & \\multicolumn{2}{c}{CIFAR-10} \\\\ +\t\t\\cmidrule(lr){5-6}\\cmidrule(lr){7-8} +\t\tNetwork & Activ. & Init. & Proc. & Acc. & Epoch & Acc. 
& Epoch \\\\ +\t\t\\midrule""" + + line_template1 = "\n\t\t{} & {} & {} & {} & {:.2f}\\% & {} & {:.2f}\\% & {} \\\\" + + footer1 = """ +\t\t\\bottomrule +\t\\end{tabular} +\\end{table} +""" + + header2 = """ +\\begin{table}[htbp] +\t\\caption{Best test set accuracy, up to 100 epochs} +\t\\label{table:tab1} +\t\\begin{tabular}{|p{30pt}|p{15pt}|p{15pt}|p{20pt}||p{20pt}|p{10pt}||p{20pt}|p{10pt}|} +\t\t\\hline +\t\t& & & & \\multicolumn{2}{c||}{\\textbf{F-MNIST}} & \\multicolumn{2}{c|}{\\textbf{CIFAR-10}} \\\\ +\t\t\\cline{5-6}\\cline{7-8} +\t\t\\textbf{Network} & \\textbf{Activ.} & \\textbf{Init.} & \\textbf{Proc.} & \\textbf{Acc.,\\%} & \\textbf{Ep.} & \\textbf{Acc.,\\%} & \\textbf{Ep.} \\\\ +\t\t\\hline""" + + line_template2 = "\n\t\t{} & {} & {} & {} & {:.2f}\\% & {} & {:.2f}\\% & {} \\\\ \\hline" + + footer2 = """ +\t\\end{tabular} +\\end{table} +""" + + if alt_style: + header = header2 + line_template = line_template2 + footer = footer2 + else: + header = header1 + line_template = line_template1 + footer = footer1 + + with open(file_name, 'w') as f: + f.write(header) + + for item in results: + net = item.net + var = item.net_type + + if item.dspu4 and item.p24sl: + trainer = "DSPT, P24Sl" + elif item.dspu4: + trainer = "DSPT" + elif item.p24sl: + trainer = "P24Sl" + else: + trainer = "Classic" + + base_act = item.act + + act_str = var.upper() if var != "base" else base_act + init_str = base_act if var != "base" else "N/A" + + item_value = results[item] + + line_str = line_template.format( + net, act_str, init_str, trainer, + item_value.fmnist.best_acc, item_value.fmnist.best_ep, + item_value.cifar10.best_acc, item_value.cifar10.best_ep + ) + f.write(line_str) + + f.write(footer) + + +def results_to_tex_sep_tables(net: str, results: SuperSumm, file_name: str, alt_style: bool = False): + header1 = """ +\\begin{table}[htbp] +\t\\caption{Best test set accuracy, up to 100 epochs} +\t\\label{table:tab1} +\t\\begin{tabular}{lllcccc} +\t\t\\toprule +\t\t& & & \\multicolumn{2}{c}{Fashion-MNIST} & \\multicolumn{2}{c}{CIFAR-10} \\\\ +\t\t\\cmidrule(lr){4-5}\\cmidrule(lr){6-7} +\t\tActiv. & Init. & Proc. & Acc. & Epoch & Acc. & Epoch \\\\ +\t\t\\midrule""" + + line_template1 = "\n\t\t{} & {} & {} & {:.2f}\\% & {} & {:.2f}\\% & {} \\\\" + + footer1 = """ +\t\t\\bottomrule +\t\\end{tabular} +\\end{table} +""" + + header2 = """ +\\begin{table}[htbp] +\t\\caption{Best test set accuracy, up to 100 epochs} +\t\\label{table:tab1} +\t\\begin{tabular}{|p{20pt}|p{20pt}|p{45pt}|p{25pt}|p{10pt}|p{25pt}|p{10pt}|} +\t\t\\hline +\t\t& & & \\multicolumn{2}{c|}{\\textbf{F-MNIST}} & \\multicolumn{2}{c|}{\\textbf{CIFAR-10}} \\\\ +\t\t\\cline{4-5}\\cline{6-7} +\t\t\\textbf{Activ.} & \\textbf{Init.} & \\textbf{Procedure} & \\textbf{Acc.,\\%} & \\textbf{Ep.} & \\textbf{Acc.,\\%} & \\textbf{Ep.} \\\\ +\t\t\\hline""" + + line_template2 = "\n\t\t{} & {} & {} & {:.2f}\\% & {} & {:.2f}\\% & {} \\\\ \\hline" + + footer2 = """ +\t\\end{tabular} +\\end{table} +""" + + if alt_style: + header = header2 + line_template = line_template2 + footer = footer2 + else: + header = header1 + line_template = line_template1 + footer = footer1 + + with open(file_name, 'w') as f: + f.write(header) + + for item in results: + if net != item.net: + continue + + var = item.net_type + + if item.dspu4 and item.p24sl: + trainer = "DSPT, P24Sl" + elif item.dspu4: + trainer = "DSPT" + elif item.p24sl: + trainer = "P24Sl" + else: + trainer = "Classic" + + base_act = item.act if item.act != 'Sigmoid' else 'Sigm.' 
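+            # Column semantics: for the adaptive variants ("ahaf", "leaf",
+            # ...) the "Activ." column shows the variant name in upper case
+            # and "Init." shows the base activation it was initialised from;
+            # for the "base" networks "Activ." is the plain activation and
+            # "Init." is reported as "N/A".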
+ + act_str = var.upper() if var != "base" else base_act + init_str = base_act if var != "base" else "N/A" + + item_value = results[item] + + line_str = line_template.format( + act_str, init_str, trainer, + item_value.fmnist.best_acc, item_value.fmnist.best_ep, + item_value.cifar10.best_acc, item_value.cifar10.best_ep + ) + f.write(line_str) + + f.write(footer) + + +def extend_super_summary( + net: str, ds: str, results: List[SummaryItem], ss: SuperSumm +): + for result in results: + if result.af_cnn == result.af_ffn: + base_af = result.af_cnn + else: + base_af = f'{result.af_cnn}, {result.af_ffn}' + + key = SuperSummKey(net, result.net_type, base_af, result.tuned, result.dspu4, result.p24sl) + if key not in ss: + # initialize this combination + ss[key] = SuperSummItem( + SuperSummSubItem(-1.0, -1), + SuperSummSubItem(-1.0, -1), + ) + + if ds == 'F-MNIST': + ss[key] = SuperSummItem( + SuperSummSubItem(result.best_acc, result.best_ep), + ss[key].cifar10 + ) + elif ds == 'CIFAR-10': + ss[key] = SuperSummItem( + ss[key].fmnist, + SuperSummSubItem(result.best_acc, result.best_ep) + ) + + +def main(): + net_names = ["LeNet-5", "KerasNet"] + ds_names = ["F-MNIST", "CIFAR-10"] + net_ds_combinations = itertools.product(net_names, ds_names) + opt = "adam" + all_lin_uns = {} # type: SuperSumm + all_bou_fns = {} # type: SuperSumm + + for n, ds in net_ds_combinations: + nets_vs_ahaf = [ + NetInfo(n, "base", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "ReLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=True, opt_name=opt), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "ReLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + NetInfo(n, "base", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "ahaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "SiLU", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + + lin_uns_results = summarize(n, ds, f"lin_un_{opt}", nets_vs_ahaf) + extend_super_summary(n, ds, lin_uns_results, all_lin_uns) + + nets_vs_fuzzy = [ + NetInfo(n, "base", ds, "Tanh", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "Tanh", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + NetInfo(n, "base", ds, "Sigmoid", 100, patched=False, fine_tuned=False, opt_name=opt), + NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=False, p24sl=True, opt_name=opt), + NetInfo(n, "leaf", ds, "Sigmoid", 100, patched=False, fine_tuned=False, dspu4=True, p24sl=True, opt_name=opt), + ] + + bou_fns_results = summarize(n, ds, f"bou_fn_{opt}", nets_vs_fuzzy) + extend_super_summary(n, ds, bou_fns_results, all_bou_fns) + + results_to_tex_sep_tables( + 'LeNet-5', all_lin_uns, + f'runs/lenet5_all_summary_lin_un_{opt}_100ep_stat_table.tex', + alt_style=True + ) + results_to_tex_sep_tables( + 'KerasNet', all_lin_uns, + f'runs/kerasnet_all_summary_lin_un_{opt}_100ep_stat_table.tex', + alt_style=True + ) + 
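+    # Repeat the LaTeX export for the bounded-function (Tanh / Sigmoid)
+    # summaries collected in all_bou_fns.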
results_to_tex_sep_tables( + 'LeNet-5', all_bou_fns, + f'runs/lenet5_all_summary_bou_fn_{opt}_100ep_stat_table.tex', + alt_style=True + ) + results_to_tex_sep_tables( + 'KerasNet', all_bou_fns, + f'runs/kerasnet_all_summary_bou_fn_{opt}_100ep_stat_table.tex', + alt_style=True + ) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4717743 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torchinfo~=1.7.2 +torch~=2.0.0 +torchvision~=0.15.1 +matplotlib~=3.3.2 +cycler~=0.10.0 \ No newline at end of file diff --git a/run_experiment.sh b/run_experiment.sh new file mode 100755 index 0000000..5e9c69b --- /dev/null +++ b/run_experiment.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +export CUBLAS_WORKSPACE_CONFIG=:4096:8 +export PYTHONPATH=".:$PYTHONPATH" + +SRGK_VENV_PYTHON="./venv/bin/python3" + +if [[ -f "$SRGK_VENV_PYTHON" ]]; then + SRGK_PYEXEC="$SRGK_VENV_PYTHON" +else + SRGK_PYEXEC="python3" +fi + + +"$SRGK_PYEXEC" "$@" diff --git a/runs/.gitkeep b/runs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/check_ahaf_manual_grad.py b/tests/check_ahaf_manual_grad.py new file mode 100644 index 0000000..cb6c91e --- /dev/null +++ b/tests/check_ahaf_manual_grad.py @@ -0,0 +1,21 @@ +import torch +from torch.autograd import gradcheck + +from adaptive_afs.cont.ahaf import _ahaf + + +def main(): + w = 16 + h = 21 + + ins = ( + torch.randn(w, h, dtype=torch.double, requires_grad=True), + torch.randn(w, h, dtype=torch.double, requires_grad=True), + torch.randn(w, h, dtype=torch.double, requires_grad=True), + ) + test = gradcheck(_ahaf, ins, eps=1e-6, atol=1e-4) + print(test) + + +if __name__ == "__main__": + main() diff --git a/tests/check_kerasnet_aaf_init.py b/tests/check_kerasnet_aaf_init.py new file mode 100644 index 0000000..e908cf1 --- /dev/null +++ b/tests/check_kerasnet_aaf_init.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +import torchinfo + +from nns_aaf import KerasNetAaf, AfDefinition + + +def print_act_functions(net: KerasNetAaf): + print(net.activations) + + +def main(): + nn_defs = [ + (None, None), # expected: all ReLU + (AfDefinition("SiLU"), AfDefinition("SiLU")), # expected: all SiLU + + # expected: SiLU in CNN, adaptive SiLU in FFN, range: -10...+10 + ( + AfDefinition("SiLU"), + AfDefinition( + "SiLU", AfDefinition.AfType.ADA_AHAF + ) + ), + + # expected: SiLU in CNN, HardTanh in FFN, range: -10...+10 + ( + AfDefinition("SiLU"), + AfDefinition( + "HardTanh", AfDefinition.AfType.TRAD, + AfDefinition.AfInterval(-10.0, +10.0) + ) + ), + + # expected: SiLU in CNN, Fuzzy HardTanh in FFN, range: -10...+10 + ( + AfDefinition("SiLU"), + AfDefinition( + "HardTanh", AfDefinition.AfType.ADA_FUZZ, + AfDefinition.AfInterval(-10.0, +10.0, n_segments=12) + ) + ), + + # expected: + # AHAF as SiLU in CNN, + # Fuzzy Sigmoid in FFN, range: -3...+3, + ( + AfDefinition("SiLU", AfDefinition.AfType.ADA_AHAF), + AfDefinition( + "Sigmoid", AfDefinition.AfType.ADA_FUZZ, + AfDefinition.AfInterval(-3.0, +3.0, n_segments=12) + ) + ), + + # expected: SiLU in CNN, AHAF as SiLU in FFN + ( + AfDefinition("SiLU", AfDefinition.AfType.TRAD), + AfDefinition("SiLU", AfDefinition.AfType.ADA_AHAF), + ), + + # expected: all LEAF as SiLU + ( + AfDefinition("SiLU", AfDefinition.AfType.ADA_LEAF), + AfDefinition("SiLU", AfDefinition.AfType.ADA_LEAF), + ) + ] + + batch_size = 64 + image_dim = (3, 32, 32) + input_size = (batch_size, *image_dim) + + 
for d in nn_defs: + net = KerasNetAaf(flavor='CIFAR10', af_conv=d[0], af_fc=d[1]) + torchinfo.summary(net, input_size=input_size) + print_act_functions(net) + + +if __name__ == "__main__": + main() diff --git a/tests/check_leaf_as_relu_grads.py b/tests/check_leaf_as_relu_grads.py new file mode 100644 index 0000000..f5aeca6 --- /dev/null +++ b/tests/check_leaf_as_relu_grads.py @@ -0,0 +1,40 @@ +import torch + + +def main(): + torch.manual_seed(42) + torch.use_deterministic_algorithms(mode=True) + + u = torch.randn(1, 99) + p1 = 1.0 + p2 = 0.0 + p3 = 1e5 + p4 = 0.0 + + print(u) + + grad_u = ( + (p1 * torch.sigmoid(p3 * u)) + + + (p1 * u + p2) + * torch.sigmoid(p3 * u) + * torch.sigmoid(-p3 * u) + * p3) + + grad_p1 = (u * torch.sigmoid(p3 * u)) + grad_p2 = torch.sigmoid(p3 * u) + grad_p3 = ( + (p1 * u + p2) + * torch.sigmoid(p3 * u) + * torch.sigmoid(-p3 * u) + * u + ) + + print(grad_u) + print(grad_p1) + print(grad_p2) + print(grad_p3) + + +if __name__ == "__main__": + main() diff --git a/tests/check_leaf_manual_grad.py b/tests/check_leaf_manual_grad.py new file mode 100644 index 0000000..3e820db --- /dev/null +++ b/tests/check_leaf_manual_grad.py @@ -0,0 +1,23 @@ +import torch +from torch.autograd import gradcheck + +from adaptive_afs.cont.leaf import _leaf + + +def main(): + w = 16 + h = 21 + + ins = ( + torch.randn(w, h, dtype=torch.double, requires_grad=True), + torch.randn(w, h, dtype=torch.double, requires_grad=True), + torch.randn(w, h, dtype=torch.double, requires_grad=True), + torch.randn(w, h, dtype=torch.double, requires_grad=True), + torch.randn(w, h, dtype=torch.double, requires_grad=True), + ) + test = gradcheck(_leaf, ins, eps=1e-6, atol=1e-4) + print(test) + + +if __name__ == "__main__": + main() diff --git a/tests/check_lenet_aaf_init.py b/tests/check_lenet_aaf_init.py new file mode 100644 index 0000000..3786e14 --- /dev/null +++ b/tests/check_lenet_aaf_init.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +import torchinfo + +from nns_aaf import LeNetAaf, AfDefinition + + +def print_act_functions(net: LeNetAaf): + print(net.activations) + + +def main(): + nn_defs = [ + (None, None), # expected: all ReLU + (AfDefinition("SiLU"), AfDefinition("SiLU")), # expected: all SiLU + + # expected: SiLU in CNN, adaptive SiLU in FFN, range: -10...+10 + ( + AfDefinition("SiLU"), + AfDefinition( + "SiLU", AfDefinition.AfType.ADA_AHAF + ) + ), + + # expected: SiLU in CNN, HardTanh in FFN, range: -10...+10 + ( + AfDefinition("SiLU"), + AfDefinition( + "HardTanh", AfDefinition.AfType.TRAD, + AfDefinition.AfInterval(-10.0, +10.0) + ) + ), + + # expected: SiLU in CNN, Fuzzy HardTanh in FFN, range: -10...+10 + ( + AfDefinition("SiLU"), + AfDefinition( + "HardTanh", AfDefinition.AfType.ADA_FUZZ, + AfDefinition.AfInterval(-10.0, +10.0, n_segments=12) + ) + ), + + # expected: + # AHAF as SiLU in CNN, + # Fuzzy Sigmoid in FFN, range: -3...+3, + ( + AfDefinition("SiLU", AfDefinition.AfType.ADA_AHAF), + AfDefinition( + "Sigmoid", AfDefinition.AfType.ADA_FUZZ, + AfDefinition.AfInterval(-3.0, +3.0, n_segments=12) + ) + ), + + # expected: SiLU in CNN, AHAF as SiLU in FFN + ( + AfDefinition("SiLU", AfDefinition.AfType.TRAD), + AfDefinition("SiLU", AfDefinition.AfType.ADA_AHAF), + ), + + # expected: all LEAF as SiLU + ( + AfDefinition("SiLU", AfDefinition.AfType.ADA_LEAF), + AfDefinition("SiLU", AfDefinition.AfType.ADA_LEAF), + ) + ] + + batch_size = 64 + image_dim = (1, 28, 28) + input_size = (batch_size, *image_dim) + + for d in nn_defs: + net = LeNetAaf(flavor='F-MNIST', af_conv=d[0], af_fc=d[1]) + 
torchinfo.summary(net, input_size=input_size)
+        # Show which activation modules were actually constructed for this variant.
+        print_act_functions(net)
+
+
+if __name__ == "__main__":
+    main()
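
The manual expressions printed by `tests/check_leaf_as_relu_grads.py` are the partial derivatives of a LEAF unit of the form `f(u) = (p1*u + p2)*sigmoid(p3*u) + p4`, evaluated at the ReLU-like parameter set `p1 = 1, p2 = 0, p3 = 1e5, p4 = 0`. The sketch below shows one way to cross-check that closed form against autograd; it assumes only the functional form above, and the name `leaf_ref`, the moderate `p3` value, and the default `allclose` tolerances are illustrative choices rather than repository code (the actual kernel lives in `adaptive_afs.cont.leaf._leaf` and is verified separately by `tests/check_leaf_manual_grad.py`).

```python
import torch


def leaf_ref(u, p1, p2, p3, p4):
    """Reference LEAF form implied by the manual gradient expressions."""
    return (p1 * u + p2) * torch.sigmoid(p3 * u) + p4


def main():
    torch.manual_seed(42)
    u = torch.randn(1, 99, dtype=torch.double, requires_grad=True)

    # Moderate slope so sigmoid(p3 * u) is not fully saturated in this check;
    # the ReLU-like configuration in the repository uses p3 = 1e5 instead.
    p1, p2, p3, p4 = 1.0, 0.0, 5.0, 0.0

    # Autograd gradient of sum(f(u)) with respect to u.
    leaf_ref(u, p1, p2, p3, p4).sum().backward()

    # Closed-form df/du, written as in check_leaf_as_relu_grads.py.
    with torch.no_grad():
        grad_u_manual = (
            p1 * torch.sigmoid(p3 * u)
            + (p1 * u + p2)
            * torch.sigmoid(p3 * u)
            * torch.sigmoid(-p3 * u)
            * p3
        )

    print(torch.allclose(u.grad, grad_u_manual))


if __name__ == "__main__":
    main()
```

The same pattern extends to the parameter gradients `grad_p1` through `grad_p3` printed by the script; `tests/check_leaf_manual_grad.py` remains the authoritative check, since it runs `torch.autograd.gradcheck` directly against `_leaf`.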