Skip to content

Commit

Permalink
rearrange tests
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Jul 21, 2024
1 parent 0b9f278 commit 4ab4a03
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 39 deletions.
36 changes: 0 additions & 36 deletions py/rvspecfit/nn/NNInterpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,6 @@
import numpy as np


class BigInterpolator(tonn.Module):

def __init__(self,
indim=None,
nlayers=None,
width=None,
npc=None,
npix=None,
nstack=None):
super(BigInterpolator, self).__init__()
self.stacks = tonn.ModuleList()
self.nstack = nstack
self.nlayers = nlayers
self.indim = indim
self.npix = npix
self.width = width
self.npc = npc
self.initLayers()

def initLayers(self):
for i in range(self.nstack):
curint = NNInterpolator(indim=self.indim,
nlayers=self.nlayers,
width=self.width,
npc=self.npc,
npix=self.npix)
self.stacks.append(curint)

# self.add_module('stack%d' % i, curint)

def forward(self, x):
mylist = [s(x) for s in self.stacks]
ret = torch.stack(mylist, dim=0).sum(dim=0)
return ret


class NNInterpolator(tonn.Module):

def __init__(self,
Expand Down
1 change: 0 additions & 1 deletion tests/make_templ_nn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@ RVS_MAKE_CCF="$COV `command -v rvs_make_ccf`"

$RVS_READ_GRID --prefix $TEMPLPREF --templdb $PREFIX/files.db
$RVS_MAKE_INTERPOL --air --setup $BNAME --lambda0 $BLAM0 --lambda1 $BLAM1 --resol $BRESOL --step $BSTEP --templdb ${PREFIX}/files.db --oprefix ${PREFIX}/ --templprefix $TEMPLPREF --fixed_fwhm --wavefile $WAVEFILE --nthreads 1

$RVS_MAKE_ND --dir ${PREFIX}/ --pca_init --npc 10 --cpu --setup $BNAME
6 changes: 4 additions & 2 deletions tests/test_01template_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def run_script(script):


def test_scripts():
run_script(path + '/make_templ_nn.sh')

run_script(path + '/make_templ.sh')
run_script(path + '/gen_test_templ.sh')

run_script(path + '/make_templ_regul.sh')
run_script(path + '/make_templ_nn.sh')
run_script(path + '/gen_test_templ_grid.sh')
run_script(path + '/gen_test_templ.sh')

0 comments on commit 4ab4a03

Please sign in to comment.