diff --git a/py/rvspecfit/nn/NNInterpolator.py b/py/rvspecfit/nn/NNInterpolator.py new file mode 100644 index 0000000..818c4c9 --- /dev/null +++ b/py/rvspecfit/nn/NNInterpolator.py @@ -0,0 +1,133 @@ +import torch.nn as tonn +import torch +from collections import OrderedDict +import numpy as np + + +class BigInterpolator(tonn.Module): + + def __init__(self, + indim=None, + nlayers=None, + width=None, + npc=None, + npix=None, + nstack=None): + super(BigInterpolator, self).__init__() + self.stacks = tonn.ModuleList() + self.nstack = nstack + self.nlayers = nlayers + self.indim = indim + self.npix = npix + self.width = width + self.npc = npc + self.initLayers() + + def initLayers(self): + for i in range(self.nstack): + curint = NNInterpolator(indim=self.indim, + nlayers=self.nlayers, + width=self.width, + npc=self.npc, + npix=self.npix) + self.stacks.append(curint) + + # self.add_module('stack%d' % i, curint) + + def forward(self, x): + mylist = [s(x) for s in self.stacks] + ret = torch.stack(mylist, dim=0).sum(dim=0) + return ret + + +class NNInterpolator(tonn.Module): + + def __init__(self, + indim=None, + nlayers=None, + width=None, + npc=None, + npix=None): + super(NNInterpolator, self).__init__() + self.indim = indim + self.nlayers = nlayers + self.width = width + self.npc = npc + self.npix = npix + self.initLayers() + + def initLayers(self): + # Layer declaration + shapes = [ + (self.indim, self.width) + ] + [(self.width, self.width)] * self.nlayers + [(self.width, self.npc) + ] + # self.L0 = tonn.Linear(self.indim, self.width) + layer_dict = OrderedDict() + # NL = tonn.Tanh + # NL = tonn.LeakyReLU + NL = tonn.ReLU + NL = tonn.SiLU + # NL = tonn.CELU + # NL = tonn.LeakyReLU + """ sequence here is + * is (indim x width) layer with bias + * Nonlinearity + + * (width x width) layer with no bias + * batchnorm + * nonlinearity + * The last 3 lements repeated nlayers times + + * (width x npc) layer with bias + * nonlinearity + * (npc x npix) layers + """ + batchnorm_after_nl = True + + for i in range(len(shapes)): + withbn = True + if i == 0 or i == len(shapes) - 1: + withbn = False + if batchnorm_after_nl: + bias = True + else: + # if batchnorm is just after linear + # bias is not needed + bias = not withbn + curl = tonn.Linear(shapes[i][0], shapes[i][1], bias=bias) + + layer_dict['lin_%d' % i] = curl + if withbn: + curbn = tonn.BatchNorm1d(shapes[i][1]) + if batchnorm_after_nl: + layer_dict['nl_%d' % i] = NL() + layer_dict['bn_%d' % i] = curbn + else: + layer_dict['bn_%d' % i] = curbn + layer_dict['nl_%d' % i] = NL() + else: + layer_dict['nl_%d' % i] = NL() + + self.model = tonn.Sequential(layer_dict) + self.pc_layer = tonn.Linear(self.npc, self.npix) + + def forward(self, curx): + curx = curx.view(-1, self.indim) + curx = self.pc_layer(self.model(curx)) + return curx + + +class Mapper: + + def __init__(self, M, S, log_ids=[0]): + self.M = M + self.S = S + self.log_ids = log_ids + + def forward(self, x): + x1 = np.asarray(x, dtype=np.float32) + y = x1 * 1 + for ii in self.log_ids: + y[..., ii] = np.log10(x1[..., ii]) + return (y - self.M) / self.S diff --git a/py/rvspecfit/nn/RVSInterpolator.py b/py/rvspecfit/nn/RVSInterpolator.py new file mode 100644 index 0000000..291dbd2 --- /dev/null +++ b/py/rvspecfit/nn/RVSInterpolator.py @@ -0,0 +1,68 @@ +import NNInterpolator +import torch +import scipy.spatial +import numpy as np +import os + +device_env_name = 'RVS_NN_DEVICE' +if device_env_name in os.environ: + device_env = os.environ[device_env_name] +else: + device_env = None + + +class RVSInterpolator: + + def __init__(self, kwargs): + self.nni = NNInterpolator.NNInterpolator(**kwargs['class_kwargs']) + + if device_env is None: + self.device = torch.device(kwargs['device']) + else: + self.device = torch.device(device_env) + self.nni.load_state_dict( + torch.load(kwargs['template_lib'] + '/' + kwargs['nn_file'], + map_location=self.device)) + # self.device = list(self.nni.children())[0][0].pc_layer.weight.device + + self.nni.to(self.device) + print('RVS NN interpolator device:', self.device) + self.nni.eval() + + def __call__(self, x): + with torch.inference_mode(): + ret = self.nni( + torch.tensor(x, dtype=torch.float32).to( + self.device)).cpu().detach().numpy().astype(np.float64) + # prevent formal overflow + return np.exp(np.clip(ret, -300, 300)).flatten() + + +class OutsideInterpolator: + + def __init__(self, kwargs0): + kwargs = kwargs0['outside_kwargs'] + pts = kwargs['pts'] + # separate first two dims + # and last two dims + xdim2 = pts[:, :2] + ydim2 = pts[:, 2:] + xconv = scipy.spatial.ConvexHull(xdim2) + xvec = xdim2[xconv.vertices] + yconv = scipy.spatial.ConvexHull(ydim2) + yvec = ydim2[yconv.vertices] + self.xtriang = scipy.spatial.Delaunay(xvec) + self.ytriang = scipy.spatial.Delaunay(yvec) + self.tree = scipy.spatial.cKDTree(pts) + + def __call__(self, p): + if self.xtriang.find_simplex(p[:2]) < 0 or self.ytriang.find_simplex( + p[2:]) < 0: + return self.tree.query(p, 4)[0].mean() + return 0 + + @staticmethod + def generate_pts(vecs): + return vecs + # conv = scipy.spatial.ConvexHull(vecs) + # return vecs[conv.vertices]