-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #384 from es-ude/365-implement-inference-execution…
…-with-pytest 365 implement inference execution with pytest
- Loading branch information
Showing
67 changed files
with
2,794 additions
and
2,731 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
141 changes: 141 additions & 0 deletions
141
elasticai/creator/nn/fixed_point/conv1d/_testbench_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
from typing import Any, Callable | ||
|
||
import pytest | ||
import torch | ||
|
||
from elasticai.creator.vhdl.auto_wire_protocols.port_definitions import create_port | ||
from elasticai.creator.vhdl.design.ports import Port | ||
|
||
from ..number_converter import FXPParams, NumberConverter | ||
from .testbench import Conv1dTestbench | ||
from .design import Conv1dDesign | ||
|
||
|
||
class DummyConv1d: | ||
def __init__(self, fxp_params: FXPParams, in_channels: int, out_channels: int): | ||
self.name: str = "conv1d" | ||
self.kernel_size: int = 1 | ||
self.input_signal_length = 1 | ||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
self.port: Port = create_port( | ||
y_width=fxp_params.total_bits, | ||
x_width=fxp_params.total_bits, | ||
x_count=1, | ||
y_count=2, | ||
) | ||
|
||
|
||
def parameters_for_reported_content_parsing(fxp_params, input_expected_pairs): | ||
def add_expected_prefix_to_pairs(pairs): | ||
_converter_for_batch = NumberConverter( | ||
FXPParams(8, 0) | ||
) # max for 255 lines of inputs | ||
pairs_with_prefix = list() | ||
for i, (pairs_text, pairs_number) in enumerate(pairs): | ||
pairs_with_prefix.append(list()) | ||
pairs_with_prefix[i].append(list()) | ||
pairs_with_prefix[i].append(pairs_number) | ||
for batch_number, batch_channel_text in enumerate(pairs_text): | ||
for out_channel_text in batch_channel_text: | ||
for value_text in out_channel_text: | ||
pairs_with_prefix[i][0].append( | ||
f"result: {_converter_for_batch.integer_to_bits(batch_number)}," | ||
f" {value_text}" | ||
) | ||
return pairs_with_prefix | ||
|
||
pairs_with_prefix = [ | ||
(fxp_params, a, b) | ||
for a, b in add_expected_prefix_to_pairs(input_expected_pairs) | ||
] | ||
return pairs_with_prefix | ||
|
||
|
||
@pytest.fixture | ||
def create_uut() -> Callable[[FXPParams, int, int], Conv1dDesign]: | ||
def create(fxp_params, in_channels: int, out_channels: int) -> Conv1dDesign: | ||
return DummyConv1d(fxp_params, in_channels=in_channels, out_channels=out_channels) | ||
|
||
return create | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"fxp_params, reported, y", ( | ||
parameters_for_reported_content_parsing( | ||
fxp_params=FXPParams(total_bits=3, frac_bits=0), | ||
input_expected_pairs=[ | ||
([[["010"]]], [[[2.0]]]), | ||
([[["001", "010"]]], [[[1.0, 2.0]]]), | ||
([[["111", "001"]]], [[[-1.0, 1.0]]]), | ||
] | ||
) + | ||
parameters_for_reported_content_parsing( | ||
fxp_params=FXPParams(total_bits=4, frac_bits=1), | ||
input_expected_pairs=[ | ||
([[["0001", "1111"]]], [[[0.5, -0.5]]]), | ||
([[["0001", "0011"]], [["1000", "1111"]]], [[[0.5, 1.5]], [[-4.0, -0.5]]]), | ||
] | ||
) | ||
) | ||
) | ||
def test_parse_reported_content_one_out_channel(fxp_params, reported, y, create_uut): | ||
in_channels = None | ||
out_channels = 1 | ||
bench = Conv1dTestbench( | ||
name="conv1d_testbench", fxp_params=fxp_params, uut=create_uut(fxp_params, in_channels, out_channels) | ||
) | ||
print(reported) | ||
assert y == bench.parse_reported_content(reported) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"fxp_params, reported, y", ( | ||
parameters_for_reported_content_parsing( | ||
fxp_params=FXPParams(total_bits=3, frac_bits=0), | ||
input_expected_pairs=[ | ||
([[["010"],["010"]]], [[[2.0],[2.0]]]), | ||
([[["001", "010"], ["001", "010"]]], [[[1.0, 2.0], [1.0, 2.0]]]), | ||
([[["111", "001"], ["111", "001"]]], [[[-1.0, 1.0], [-1.0, 1.0]]]), | ||
] | ||
) + | ||
parameters_for_reported_content_parsing( | ||
fxp_params=FXPParams(total_bits=4, frac_bits=1), | ||
input_expected_pairs=[ | ||
([[["0001", "1111"], ["0001", "1111"]]], [[[0.5, -0.5], [0.5, -0.5]]]), | ||
([[["0001", "0011"], ["0001", "0011"]], [["1000", "1111"], ["1000", "1111"]]], | ||
[[[0.5, 1.5], [0.5, 1.5]], [[-4.0, -0.5], [-4.0, -0.5]]]), | ||
] | ||
) | ||
) | ||
) | ||
def test_parse_reported_content_two_out_channel(fxp_params, reported, y, create_uut): | ||
in_channels = None | ||
out_channels = 2 | ||
bench = Conv1dTestbench( | ||
name="conv1d_testbench", fxp_params=fxp_params, uut=create_uut(fxp_params, in_channels, out_channels) | ||
) | ||
print(reported) | ||
assert y == bench.parse_reported_content(reported) | ||
|
||
def test_input_preparation_with_one_in_channel(create_uut): | ||
fxp_params = FXPParams(total_bits=3, frac_bits=0) | ||
in_channels = 1 | ||
out_channels = None | ||
bench = Conv1dTestbench( | ||
name="bench_name", fxp_params=fxp_params, uut=create_uut(fxp_params, in_channels, out_channels), | ||
) | ||
input = torch.Tensor([[[1.0, 1.0]]]) | ||
expected = [{"x_0_0": "001", "x_0_1": "001"}] | ||
assert expected == bench.prepare_inputs(input.tolist()) | ||
|
||
def test_input_preparation_with_two_in_channel(create_uut): | ||
fxp_params = FXPParams(total_bits=3, frac_bits=0) | ||
in_channels = 1 | ||
out_channels = None | ||
bench = Conv1dTestbench( | ||
name="bench_name", fxp_params=fxp_params, uut=create_uut(fxp_params, in_channels, out_channels), | ||
) | ||
input = torch.Tensor([[[1.0, 1.0], [1.0, 2.0]]]) | ||
expected = [{"x_0_0": "001", "x_0_1": "001", "x_1_0": "001", "x_1_1": "010"}] | ||
assert expected == bench.prepare_inputs(input.tolist()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,55 @@ | ||
-- Dummy File for testing implementation of conv1d Design | ||
${total_bits} | ||
${frac_bits} | ||
${in_channels} | ||
${out_channels} | ||
${kernel_size} | ||
library ieee; | ||
use ieee.std_logic_1164.all; | ||
|
||
entity ${name} is | ||
port ( | ||
enable : in std_logic; | ||
clock : in std_logic; | ||
x_address : out std_logic_vector(${x_address_width}-1 downto 0); | ||
y_address : in std_logic_vector(${y_address_width}-1 downto 0); | ||
|
||
x : in std_logic_vector(${x_width}-1 downto 0); | ||
y : out std_logic_vector(${y_width}-1 downto 0); | ||
|
||
done : out std_logic | ||
); | ||
end; | ||
|
||
architecture rtl of ${name} is | ||
constant TOTAL_WIDTH : natural := ${x_width}; | ||
constant FRAC_WIDTH : natural := ${frac_width}; | ||
constant VECTOR_WIDTH : natural := ${vector_width}; | ||
constant KERNEL_SIZE : natural := ${kernel_size}; | ||
constant IN_CHANNELS : natural := ${in_channels}; | ||
constant OUT_CHANNELS : natural := ${out_channels}; | ||
constant X_ADDRESS_WIDTH : natural := ${x_address_width}; | ||
constant Y_ADDRESS_WIDTH : natural := ${y_address_width}; | ||
|
||
signal reset : std_logic; | ||
|
||
begin | ||
|
||
reset <= not enable; | ||
|
||
${name}_conv1d : entity work.conv1d_fxp_MAC_RoundToZero | ||
generic map( | ||
TOTAL_WIDTH => TOTAL_WIDTH, | ||
FRAC_WIDTH => FRAC_WIDTH, | ||
VECTOR_WIDTH => VECTOR_WIDTH, | ||
KERNEL_SIZE => KERNEL_SIZE, | ||
IN_CHANNELS => IN_CHANNELS, | ||
OUT_CHANNELS => OUT_CHANNELS, | ||
X_ADDRESS_WIDTH => X_ADDRESS_WIDTH, | ||
Y_ADDRESS_WIDTH => Y_ADDRESS_WIDTH | ||
) | ||
port map ( | ||
clock => clock, | ||
enable => enable, | ||
reset => reset, | ||
x => x, | ||
x_address => x_address, | ||
y => y, | ||
y_address => y_address, | ||
done => done | ||
); | ||
end rtl; |
Oops, something went wrong.