-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import torch.nn as tonn | ||
import torch | ||
from collections import OrderedDict | ||
import numpy as np | ||
|
||
|
||
class BigInterpolator(tonn.Module): | ||
|
||
def __init__(self, | ||
indim=None, | ||
nlayers=None, | ||
width=None, | ||
npc=None, | ||
npix=None, | ||
nstack=None): | ||
super(BigInterpolator, self).__init__() | ||
self.stacks = tonn.ModuleList() | ||
self.nstack = nstack | ||
self.nlayers = nlayers | ||
self.indim = indim | ||
self.npix = npix | ||
self.width = width | ||
self.npc = npc | ||
self.initLayers() | ||
|
||
def initLayers(self): | ||
for i in range(self.nstack): | ||
curint = NNInterpolator(indim=self.indim, | ||
nlayers=self.nlayers, | ||
width=self.width, | ||
npc=self.npc, | ||
npix=self.npix) | ||
self.stacks.append(curint) | ||
|
||
# self.add_module('stack%d' % i, curint) | ||
|
||
def forward(self, x): | ||
mylist = [s(x) for s in self.stacks] | ||
ret = torch.stack(mylist, dim=0).sum(dim=0) | ||
return ret | ||
|
||
|
||
class NNInterpolator(tonn.Module): | ||
|
||
def __init__(self, | ||
indim=None, | ||
nlayers=None, | ||
width=None, | ||
npc=None, | ||
npix=None): | ||
super(NNInterpolator, self).__init__() | ||
self.indim = indim | ||
self.nlayers = nlayers | ||
self.width = width | ||
self.npc = npc | ||
self.npix = npix | ||
self.initLayers() | ||
|
||
def initLayers(self): | ||
# Layer declaration | ||
shapes = [ | ||
(self.indim, self.width) | ||
] + [(self.width, self.width)] * self.nlayers + [(self.width, self.npc) | ||
] | ||
# self.L0 = tonn.Linear(self.indim, self.width) | ||
layer_dict = OrderedDict() | ||
# NL = tonn.Tanh | ||
# NL = tonn.LeakyReLU | ||
NL = tonn.ReLU | ||
NL = tonn.SiLU | ||
# NL = tonn.CELU | ||
# NL = tonn.LeakyReLU | ||
""" sequence here is | ||
* is (indim x width) layer with bias | ||
* Nonlinearity | ||
* (width x width) layer with no bias | ||
* batchnorm | ||
* nonlinearity | ||
* The last 3 lements repeated nlayers times | ||
* (width x npc) layer with bias | ||
* nonlinearity | ||
* (npc x npix) layers | ||
""" | ||
batchnorm_after_nl = True | ||
|
||
for i in range(len(shapes)): | ||
withbn = True | ||
if i == 0 or i == len(shapes) - 1: | ||
withbn = False | ||
if batchnorm_after_nl: | ||
bias = True | ||
else: | ||
# if batchnorm is just after linear | ||
# bias is not needed | ||
bias = not withbn | ||
curl = tonn.Linear(shapes[i][0], shapes[i][1], bias=bias) | ||
|
||
layer_dict['lin_%d' % i] = curl | ||
if withbn: | ||
curbn = tonn.BatchNorm1d(shapes[i][1]) | ||
if batchnorm_after_nl: | ||
layer_dict['nl_%d' % i] = NL() | ||
layer_dict['bn_%d' % i] = curbn | ||
else: | ||
layer_dict['bn_%d' % i] = curbn | ||
layer_dict['nl_%d' % i] = NL() | ||
else: | ||
layer_dict['nl_%d' % i] = NL() | ||
|
||
self.model = tonn.Sequential(layer_dict) | ||
self.pc_layer = tonn.Linear(self.npc, self.npix) | ||
|
||
def forward(self, curx): | ||
curx = curx.view(-1, self.indim) | ||
curx = self.pc_layer(self.model(curx)) | ||
return curx | ||
|
||
|
||
class Mapper: | ||
|
||
def __init__(self, M, S, log_ids=[0]): | ||
self.M = M | ||
self.S = S | ||
self.log_ids = log_ids | ||
|
||
def forward(self, x): | ||
x1 = np.asarray(x, dtype=np.float32) | ||
y = x1 * 1 | ||
for ii in self.log_ids: | ||
y[..., ii] = np.log10(x1[..., ii]) | ||
return (y - self.M) / self.S |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import NNInterpolator | ||
import torch | ||
import scipy.spatial | ||
import numpy as np | ||
import os | ||
|
||
device_env_name = 'RVS_NN_DEVICE' | ||
if device_env_name in os.environ: | ||
device_env = os.environ[device_env_name] | ||
else: | ||
device_env = None | ||
|
||
|
||
class RVSInterpolator: | ||
|
||
def __init__(self, kwargs): | ||
self.nni = NNInterpolator.NNInterpolator(**kwargs['class_kwargs']) | ||
|
||
if device_env is None: | ||
self.device = torch.device(kwargs['device']) | ||
else: | ||
self.device = torch.device(device_env) | ||
self.nni.load_state_dict( | ||
torch.load(kwargs['template_lib'] + '/' + kwargs['nn_file'], | ||
map_location=self.device)) | ||
# self.device = list(self.nni.children())[0][0].pc_layer.weight.device | ||
|
||
self.nni.to(self.device) | ||
print('RVS NN interpolator device:', self.device) | ||
self.nni.eval() | ||
|
||
def __call__(self, x): | ||
with torch.inference_mode(): | ||
ret = self.nni( | ||
torch.tensor(x, dtype=torch.float32).to( | ||
self.device)).cpu().detach().numpy().astype(np.float64) | ||
# prevent formal overflow | ||
return np.exp(np.clip(ret, -300, 300)).flatten() | ||
|
||
|
||
class OutsideInterpolator: | ||
|
||
def __init__(self, kwargs0): | ||
kwargs = kwargs0['outside_kwargs'] | ||
pts = kwargs['pts'] | ||
# separate first two dims | ||
# and last two dims | ||
xdim2 = pts[:, :2] | ||
ydim2 = pts[:, 2:] | ||
xconv = scipy.spatial.ConvexHull(xdim2) | ||
xvec = xdim2[xconv.vertices] | ||
yconv = scipy.spatial.ConvexHull(ydim2) | ||
yvec = ydim2[yconv.vertices] | ||
self.xtriang = scipy.spatial.Delaunay(xvec) | ||
self.ytriang = scipy.spatial.Delaunay(yvec) | ||
self.tree = scipy.spatial.cKDTree(pts) | ||
|
||
def __call__(self, p): | ||
if self.xtriang.find_simplex(p[:2]) < 0 or self.ytriang.find_simplex( | ||
p[2:]) < 0: | ||
return self.tree.query(p, 4)[0].mean() | ||
return 0 | ||
|
||
@staticmethod | ||
def generate_pts(vecs): | ||
return vecs | ||
# conv = scipy.spatial.ConvexHull(vecs) | ||
# return vecs[conv.vertices] |