Skip to content

Commit

Permalink
fixes to nn infra
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Feb 12, 2024
1 parent 75c610b commit 00a1aa2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion py/rvspecfit/nn/RVSInterpolator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import NNInterpolator
from . import NNInterpolator
import torch
import scipy.spatial
import numpy as np
Expand Down
6 changes: 3 additions & 3 deletions py/rvspecfit/nn/train_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sklearn.decomposition as skde
import numpy as np
import torch
from NNInterpolator import Mapper, NNInterpolator
from .NNInterpolator import Mapper, NNInterpolator

git_rev = rvspecfit.__version__

Expand Down Expand Up @@ -323,10 +323,10 @@ def main(args):
'log_spec': True,
'logstep': True,
'module': 'rvspecfit.nn.RVSInterpolator',
'class_name': 'rvspecfit.nn.RVSInterpolator',
'class_name': 'RVSInterpolator',
'device': device_name,
'class_kwargs': kwargs,
'outside_class_name': 'rvspecfit.nn.OutsideInterpolator',
'outside_class_name': 'OutsideInterpolator',
'outside_kwargs': dict(pts=vecs),
'nn_file': finalfile,
'revision': revision,
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def read(fname):
license="BSD",
keywords="stellar spectra radial velocity",
url="http://github.com/segasai/rvspecfit",
packages=['rvspecfit', 'rvspecfit/desi', 'rvspecfit/weave'],
packages=[
'rvspecfit', 'rvspecfit/desi', 'rvspecfit/weave', 'rvspecfit/nn'
],
scripts=[fname for fname in glob.glob(os.path.join('bin', '*'))],
zip_safe=False,
package_dir={'': 'py/'},
Expand Down

0 comments on commit 00a1aa2

Please sign in to comment.