From 00a1aa2c06e2e471ed5204e8f3b13a5921457694 Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Mon, 12 Feb 2024 20:53:49 +0000 Subject: [PATCH] fixes to nn infra --- py/rvspecfit/nn/RVSInterpolator.py | 2 +- py/rvspecfit/nn/train_interpolator.py | 6 +++--- setup.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/py/rvspecfit/nn/RVSInterpolator.py b/py/rvspecfit/nn/RVSInterpolator.py index 291dbd2..94a9dd3 100644 --- a/py/rvspecfit/nn/RVSInterpolator.py +++ b/py/rvspecfit/nn/RVSInterpolator.py @@ -1,4 +1,4 @@ -import NNInterpolator +from . import NNInterpolator import torch import scipy.spatial import numpy as np diff --git a/py/rvspecfit/nn/train_interpolator.py b/py/rvspecfit/nn/train_interpolator.py index fd8b4ce..70d395a 100644 --- a/py/rvspecfit/nn/train_interpolator.py +++ b/py/rvspecfit/nn/train_interpolator.py @@ -9,7 +9,7 @@ import sklearn.decomposition as skde import numpy as np import torch -from NNInterpolator import Mapper, NNInterpolator +from .NNInterpolator import Mapper, NNInterpolator git_rev = rvspecfit.__version__ @@ -323,10 +323,10 @@ def main(args): 'log_spec': True, 'logstep': True, 'module': 'rvspecfit.nn.RVSInterpolator', - 'class_name': 'rvspecfit.nn.RVSInterpolator', + 'class_name': 'RVSInterpolator', 'device': device_name, 'class_kwargs': kwargs, - 'outside_class_name': 'rvspecfit.nn.OutsideInterpolator', + 'outside_class_name': 'OutsideInterpolator', 'outside_kwargs': dict(pts=vecs), 'nn_file': finalfile, 'revision': revision, diff --git a/setup.py b/setup.py index fe5e96b..8e612b0 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,9 @@ def read(fname): license="BSD", keywords="stellar spectra radial velocity", url="http://github.com/segasai/rvspecfit", - packages=['rvspecfit', 'rvspecfit/desi', 'rvspecfit/weave'], + packages=[ + 'rvspecfit', 'rvspecfit/desi', 'rvspecfit/weave', 'rvspecfit/nn' + ], scripts=[fname for fname in glob.glob(os.path.join('bin', '*'))], zip_safe=False, package_dir={'': 'py/'},