From 4ab4a03a878d55d6fcdd4bd672ce7edc0bdc618c Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Sun, 21 Jul 2024 21:34:18 +0100 Subject: [PATCH] rearrange tests --- py/rvspecfit/nn/NNInterpolator.py | 36 ------------------------------- tests/make_templ_nn.sh | 1 - tests/test_01template_creation.py | 6 ++++-- 3 files changed, 4 insertions(+), 39 deletions(-) diff --git a/py/rvspecfit/nn/NNInterpolator.py b/py/rvspecfit/nn/NNInterpolator.py index 818c4c9..b93bbfd 100644 --- a/py/rvspecfit/nn/NNInterpolator.py +++ b/py/rvspecfit/nn/NNInterpolator.py @@ -4,42 +4,6 @@ import numpy as np -class BigInterpolator(tonn.Module): - - def __init__(self, - indim=None, - nlayers=None, - width=None, - npc=None, - npix=None, - nstack=None): - super(BigInterpolator, self).__init__() - self.stacks = tonn.ModuleList() - self.nstack = nstack - self.nlayers = nlayers - self.indim = indim - self.npix = npix - self.width = width - self.npc = npc - self.initLayers() - - def initLayers(self): - for i in range(self.nstack): - curint = NNInterpolator(indim=self.indim, - nlayers=self.nlayers, - width=self.width, - npc=self.npc, - npix=self.npix) - self.stacks.append(curint) - - # self.add_module('stack%d' % i, curint) - - def forward(self, x): - mylist = [s(x) for s in self.stacks] - ret = torch.stack(mylist, dim=0).sum(dim=0) - return ret - - class NNInterpolator(tonn.Module): def __init__(self, diff --git a/tests/make_templ_nn.sh b/tests/make_templ_nn.sh index 87abfc9..a2d3ef9 100755 --- a/tests/make_templ_nn.sh +++ b/tests/make_templ_nn.sh @@ -20,5 +20,4 @@ RVS_MAKE_CCF="$COV `command -v rvs_make_ccf`" $RVS_READ_GRID --prefix $TEMPLPREF --templdb $PREFIX/files.db $RVS_MAKE_INTERPOL --air --setup $BNAME --lambda0 $BLAM0 --lambda1 $BLAM1 --resol $BRESOL --step $BSTEP --templdb ${PREFIX}/files.db --oprefix ${PREFIX}/ --templprefix $TEMPLPREF --fixed_fwhm --wavefile $WAVEFILE --nthreads 1 - $RVS_MAKE_ND --dir ${PREFIX}/ --pca_init --npc 10 --cpu --setup $BNAME diff --git a/tests/test_01template_creation.py b/tests/test_01template_creation.py index 1c8e85b..4508546 100644 --- a/tests/test_01template_creation.py +++ b/tests/test_01template_creation.py @@ -31,8 +31,10 @@ def run_script(script): def test_scripts(): + run_script(path + '/make_templ_nn.sh') + run_script(path + '/make_templ.sh') + run_script(path + '/gen_test_templ.sh') + run_script(path + '/make_templ_regul.sh') - run_script(path + '/make_templ_nn.sh') run_script(path + '/gen_test_templ_grid.sh') - run_script(path + '/gen_test_templ.sh')