Skip to content

Commit

Permalink
add the nn interpolators code
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Feb 12, 2024
1 parent 07fb140 commit f766bd0
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 0 deletions.
133 changes: 133 additions & 0 deletions py/rvspecfit/nn/NNInterpolator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch.nn as tonn
import torch
from collections import OrderedDict
import numpy as np


class BigInterpolator(tonn.Module):

def __init__(self,
indim=None,
nlayers=None,
width=None,
npc=None,
npix=None,
nstack=None):
super(BigInterpolator, self).__init__()
self.stacks = tonn.ModuleList()
self.nstack = nstack
self.nlayers = nlayers
self.indim = indim
self.npix = npix
self.width = width
self.npc = npc
self.initLayers()

def initLayers(self):
for i in range(self.nstack):
curint = NNInterpolator(indim=self.indim,
nlayers=self.nlayers,
width=self.width,
npc=self.npc,
npix=self.npix)
self.stacks.append(curint)

# self.add_module('stack%d' % i, curint)

def forward(self, x):
mylist = [s(x) for s in self.stacks]
ret = torch.stack(mylist, dim=0).sum(dim=0)
return ret


class NNInterpolator(tonn.Module):

def __init__(self,
indim=None,
nlayers=None,
width=None,
npc=None,
npix=None):
super(NNInterpolator, self).__init__()
self.indim = indim
self.nlayers = nlayers
self.width = width
self.npc = npc
self.npix = npix
self.initLayers()

def initLayers(self):
# Layer declaration
shapes = [
(self.indim, self.width)
] + [(self.width, self.width)] * self.nlayers + [(self.width, self.npc)
]
# self.L0 = tonn.Linear(self.indim, self.width)
layer_dict = OrderedDict()
# NL = tonn.Tanh
# NL = tonn.LeakyReLU
NL = tonn.ReLU
NL = tonn.SiLU
# NL = tonn.CELU
# NL = tonn.LeakyReLU
""" sequence here is
* is (indim x width) layer with bias
* Nonlinearity
* (width x width) layer with no bias
* batchnorm
* nonlinearity
* The last 3 lements repeated nlayers times
* (width x npc) layer with bias
* nonlinearity
* (npc x npix) layers
"""
batchnorm_after_nl = True

for i in range(len(shapes)):
withbn = True
if i == 0 or i == len(shapes) - 1:
withbn = False
if batchnorm_after_nl:
bias = True
else:
# if batchnorm is just after linear
# bias is not needed
bias = not withbn
curl = tonn.Linear(shapes[i][0], shapes[i][1], bias=bias)

layer_dict['lin_%d' % i] = curl
if withbn:
curbn = tonn.BatchNorm1d(shapes[i][1])
if batchnorm_after_nl:
layer_dict['nl_%d' % i] = NL()
layer_dict['bn_%d' % i] = curbn
else:
layer_dict['bn_%d' % i] = curbn
layer_dict['nl_%d' % i] = NL()
else:
layer_dict['nl_%d' % i] = NL()

self.model = tonn.Sequential(layer_dict)
self.pc_layer = tonn.Linear(self.npc, self.npix)

def forward(self, curx):
curx = curx.view(-1, self.indim)
curx = self.pc_layer(self.model(curx))
return curx


class Mapper:

def __init__(self, M, S, log_ids=[0]):
self.M = M
self.S = S
self.log_ids = log_ids

def forward(self, x):
x1 = np.asarray(x, dtype=np.float32)
y = x1 * 1
for ii in self.log_ids:
y[..., ii] = np.log10(x1[..., ii])
return (y - self.M) / self.S
68 changes: 68 additions & 0 deletions py/rvspecfit/nn/RVSInterpolator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import NNInterpolator
import torch
import scipy.spatial
import numpy as np
import os

device_env_name = 'RVS_NN_DEVICE'
if device_env_name in os.environ:
device_env = os.environ[device_env_name]
else:
device_env = None


class RVSInterpolator:

def __init__(self, kwargs):
self.nni = NNInterpolator.NNInterpolator(**kwargs['class_kwargs'])

if device_env is None:
self.device = torch.device(kwargs['device'])
else:
self.device = torch.device(device_env)
self.nni.load_state_dict(
torch.load(kwargs['template_lib'] + '/' + kwargs['nn_file'],
map_location=self.device))
# self.device = list(self.nni.children())[0][0].pc_layer.weight.device

self.nni.to(self.device)
print('RVS NN interpolator device:', self.device)
self.nni.eval()

def __call__(self, x):
with torch.inference_mode():
ret = self.nni(
torch.tensor(x, dtype=torch.float32).to(
self.device)).cpu().detach().numpy().astype(np.float64)
# prevent formal overflow
return np.exp(np.clip(ret, -300, 300)).flatten()


class OutsideInterpolator:

def __init__(self, kwargs0):
kwargs = kwargs0['outside_kwargs']
pts = kwargs['pts']
# separate first two dims
# and last two dims
xdim2 = pts[:, :2]
ydim2 = pts[:, 2:]
xconv = scipy.spatial.ConvexHull(xdim2)
xvec = xdim2[xconv.vertices]
yconv = scipy.spatial.ConvexHull(ydim2)
yvec = ydim2[yconv.vertices]
self.xtriang = scipy.spatial.Delaunay(xvec)
self.ytriang = scipy.spatial.Delaunay(yvec)
self.tree = scipy.spatial.cKDTree(pts)

def __call__(self, p):
if self.xtriang.find_simplex(p[:2]) < 0 or self.ytriang.find_simplex(
p[2:]) < 0:
return self.tree.query(p, 4)[0].mean()
return 0

@staticmethod
def generate_pts(vecs):
return vecs
# conv = scipy.spatial.ConvexHull(vecs)
# return vecs[conv.vertices]

0 comments on commit f766bd0

Please sign in to comment.