disable the layer noise when training
also allow flexible nonlinearity
segasai committed Nov 6, 2024
1 parent 8d63a1d commit 593a725
Showing 2 changed files with 28 additions and 47 deletions.
12 changes: 5 additions & 7 deletions py/rvspecfit/nn/NNInterpolator.py
@@ -1,4 +1,5 @@
import torch.nn as tonn
import torch.nn.modules.activation as tonnact
from collections import OrderedDict
import numpy as np

@@ -11,14 +12,16 @@ def __init__(self,
width=None,
npc=None,
npix=None,
withbn=True):
withbn=True,
nonlinearity='SiLU'):
super(NNInterpolator, self).__init__()
self.indim = indim
self.nlayers = nlayers
self.width = width
self.npc = npc
self.npix = npix
self.withbn = withbn
self.nonlinearity = nonlinearity
self.initLayers()

def initLayers(self):
@@ -29,12 +32,7 @@ def initLayers(self):
]
# self.L0 = tonn.Linear(self.indim, self.width)
layer_dict = OrderedDict()
# NL = tonn.Tanh
# NL = tonn.LeakyReLU
NL = tonn.ReLU
NL = tonn.SiLU
# NL = tonn.CELU
# NL = tonn.LeakyReLU
NL = getattr(tonnact, self.nonlinearity)
""" sequence here is
* is (indim x width) layer with bias
* Nonlinearity
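The nonlinearity is now chosen by name: initLayers resolves the string stored in self.nonlinearity against torch.nn.modules.activation, so any activation class defined there (SiLU, ReLU, Tanh, CELU, GELU, ...) can be requested without editing the code. A minimal standalone sketch of that lookup pattern, not rvspecfit code (the make_activation helper is hypothetical):

import torch
import torch.nn.modules.activation as tonnact

def make_activation(name='SiLU'):
    # Resolve the activation class by name and instantiate it;
    # getattr raises AttributeError if no such class exists.
    return getattr(tonnact, name)()

act = make_activation('GELU')              # e.g. swap SiLU for GELU
print(act(torch.linspace(-2.0, 2.0, 5)))   # apply the activation to a test tensor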
63 changes: 23 additions & 40 deletions py/rvspecfit/nn/train_interpolator.py
@@ -148,7 +148,8 @@ def main(args):
npc=npc,
nlayers=nlayers,
width=width,
npix=npix)
npix=npix,
withbn=False)
myint = NNInterpolator(**kwargs)

statefile_path = f'{directory}/tmp_state_{setup}.sav'
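For reference, the keyword arguments above are the ones NNInterpolator accepts after this commit: batch normalisation is switched off via withbn=False (presumably skipping the batch-norm layers in initLayers), while the new nonlinearity argument is left at its 'SiLU' default. A hypothetical construction with placeholder sizes (none of the numbers below come from the diff):

kwargs = dict(indim=4,        # placeholder: dimensionality of the stellar-parameter input
              npc=200,        # placeholder: number of principal components
              nlayers=3,      # placeholder: number of hidden layers
              width=256,      # placeholder: hidden-layer width
              npix=4000,      # placeholder: output spectrum length
              withbn=False)   # batch normalisation disabled, as in this commit
myint = NNInterpolator(**kwargs)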
@@ -210,63 +211,55 @@ def main(args):
counter = 0 # global counter
deltat = 0
# divstep = 0
regul_eps = 0
minlr = 1e-8
batch_move = True
layer_noise = 0.1
layer_noise = 0
for i in range(2):
if i == 0:
pass
elif i == 1:
layer_noise = 0
batch = 10000
batch_move = False
print('final loop')
params = myint.parameters()
optim = torch.optim.Adam(params, lr=lr0)
# optim = torch.optim.SGD(params, lr=lr0, nesterov=True, momentum=0.1)
sched = getSchedOptim(optim)
while True:
tstart = time.time()
counter += 1
loss0Accum = 0
regulAccum = 0
lossAccum = 0
lossAccum00 = 0
lossAccum_noised = 0
if not batch_move:
optim.zero_grad()
for Tdat, Tvecs00 in Tbatchdat:
# noise perturbed vectors
Tvecs = Tvecs00 + torch.rand(
size=Tvecs00.size()).to(train_dev) * layer_noise
if batch_on_dev:
Tdat = Tdat.to(train_dev)
Tvecs = Tvecs.to(train_dev)
if batch_move:
optim.zero_grad()
# Rfinal = myint(Tvecs) * tSD_0 + tD_0
RfinalX = myint(Tvecs00) * tSD_0 + tD_0
if regul_eps > 0:
regul = regul_eps * torch.sum(
torch.linalg.vector_norm(
torch.diff(myint.pc_layer.weight.data *
tSD_0.view(npix, 1),
dim=0))) / npix / npc
Rfinal_noised = myint(Tvecs) * tSD_0 + tD_0
if layer_noise == 0:
Rfinal00 = Rfinal_noised
else:
regul = torch.tensor(0, device=train_dev)
# loss0 = torch.sum(
# ((Rfinal - Tdat))**2) / len(Tdat) / npix
# loss0 = torch.linalg.vector_norm(Rfinal - Tdat)
# / len(Tdat) / npix
loss0 = tofu.l1_loss(RfinalX, Tdat) / spread0
loss = loss0 + regul
Rfinal00 = myint(Tvecs00) * tSD_0 + tD_0

loss_noised = tofu.l1_loss(Rfinal_noised, Tdat) / spread0
loss00 = tofu.l1_loss(Rfinal00, Tdat) / spread0
if batch_move:
torch.autograd.backward(loss)
torch.autograd.backward(loss_noised)
optim.step()

lossAccum += loss * len(Tdat) * npix
loss0Accum += loss0 * len(Tdat) * npix
regulAccum += regul
lossAccum00 += loss00 * len(Tdat) * npix
lossAccum_noised += loss_noised * len(Tdat) * npix
if not batch_move:
torch.autograd.backward(lossAccum / nspec / npix)
torch.autograd.backward(lossAccum_noised / nspec / npix)
optim.step()
sched.step(lossAccum)
sched.step(lossAccum00)
if validation:
with torch.inference_mode():
val_loss = tofu.l1_loss(
@@ -275,24 +268,14 @@ def main(args):
val_loss = val_loss.detach().cpu().numpy()
else:
val_loss = 0
loss0_V = loss0Accum.detach().cpu().numpy() / dats.size
regul_V = regulAccum.detach().cpu().numpy()
loss_V = lossAccum00.detach().cpu().numpy() / dats.size
curlr = optim.param_groups[0]['lr']
print('it %d loss %.5f' % (counter, loss0_V),
'regul %.5f' % regul_V, 'val %.5f' % val_loss, 'lr', curlr,
'time', deltat)
loss_V = loss0_V + regul_V
print('it %d loss %.5f' % (counter, loss_V), 'val %.5f' % val_loss,
'lr', curlr, 'time', deltat)
# lastloss = loss_V
losses.append(loss_V)
# if counter > 10:
# # TEMP
# break
if curlr < minlr:
break
# if np.ptp(losses[-30:]) / loss_V < 1e-3 and counter > 30:

# print('break2')
# break

if counter % 32 == 0 and counter > 0:
print('saving')
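The training loop now tracks two losses per batch: loss_noised, computed from latent vectors perturbed by layer_noise, is the one that is backpropagated, while loss00, computed from the clean vectors, drives the scheduler and the printed diagnostics. With layer_noise set to 0 the perturbation vanishes and the two coincide, which is how this commit disables the layer noise. A minimal sketch of that pattern under assumed stand-in names (model, x_clean, target, scale, offset are hypothetical; tofu is assumed to be torch.nn.functional, matching the tofu.l1_loss calls in the diff):

import torch
import torch.nn.functional as tofu

def train_step(model, optim, x_clean, target, scale, offset, layer_noise=0.0):
    # Perturb the input vectors; with layer_noise == 0 this is a no-op.
    x_noised = x_clean + torch.rand_like(x_clean) * layer_noise
    optim.zero_grad()
    pred_noised = model(x_noised) * scale + offset
    # Avoid a second forward pass when the noise is disabled.
    if layer_noise == 0:
        pred_clean = pred_noised
    else:
        pred_clean = model(x_clean) * scale + offset
    loss_noised = tofu.l1_loss(pred_noised, target)   # backpropagated
    loss_clean = tofu.l1_loss(pred_clean, target)     # monitored / fed to the scheduler
    loss_noised.backward()
    optim.step()
    return loss_clean.detach()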
