Commit

test custom module
Saran-nns committed Sep 3, 2024
1 parent 6839fdb commit c993d1c
Showing 5 changed files with 68 additions and 51 deletions.
Empty file added examples/__init__.py
40 changes: 40 additions & 0 deletions examples/custom_module.py
@@ -0,0 +1,40 @@
import torch


class MySigmoid(torch.autograd.Function):
    """Custom autograd sigmoid; saves the forward output for use in backward."""

    @staticmethod
    def forward(ctx, input):
        output = 1 / (1 + torch.exp(-input))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # The saved tensor is the sigmoid output, and sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).
        (output,) = ctx.saved_tensors
        return grad_output * output * (1 - output)


class MSELoss(torch.autograd.Function):
    """Custom autograd MSE loss: squared error summed over outputs, averaged over the batch."""

    @staticmethod
    def forward(ctx, y_pred, y):
        ctx.save_for_backward(y_pred, y)
        return ((y_pred - y) ** 2).sum() / y_pred.shape[0]

    @staticmethod
    def backward(ctx, grad_output):
        y_pred, y = ctx.saved_tensors
        # Chain grad_output through d(loss)/d(y_pred); the target y gets no gradient.
        grad_input = grad_output * 2 * (y_pred - y) / y_pred.shape[0]
        return grad_input, None


class MyModel(torch.nn.Module):
    def __init__(self, D_in, D_out):
        super(MyModel, self).__init__()
        self.w1 = torch.nn.Parameter(torch.randn(D_in, D_out), requires_grad=True)
        self.sigmoid = MySigmoid.apply

    def forward(self, x):
        y_pred = self.sigmoid(x.mm(self.w1))
        return y_pred
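
Not part of the commit, but a quick way to sanity-check the custom autograd Functions above is torch.autograd.gradcheck, which compares the hand-written backward against finite differences in double precision. A minimal sketch, assuming the new file is importable as examples.custom_module:

import torch

from examples.custom_module import MSELoss, MySigmoid

# gradcheck wants double precision and requires_grad=True on the inputs it perturbs.
x = torch.randn(6, 4, dtype=torch.double, requires_grad=True)
y_pred = torch.randn(6, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(6, 3, dtype=torch.double)

print(torch.autograd.gradcheck(MySigmoid.apply, (x,), eps=1e-6, atol=1e-4))      # True if backward matches
print(torch.autograd.gradcheck(MSELoss.apply, (y_pred, y), eps=1e-6, atol=1e-4))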
38 changes: 0 additions & 38 deletions examples/torch_lightning.py

This file was deleted.

13 changes: 10 additions & 3 deletions gradients/gradients.py
@@ -114,11 +114,19 @@ def forward(self, model, param, eps):
     def check_(self, anagrad, numgrad_plus, numgrad_minus):
 
         numgrad = (numgrad_plus - numgrad_minus) / (2.0 * self.eps)
+
         diff = torch.norm(anagrad - numgrad) / (
             torch.norm(anagrad) + torch.norm(numgrad)
         )
-        if diff > 1e-7:
-            print(f"Parameter {self.param} Relative difference {diff} Check Failed")
+
+        if diff > 1e-6:
+            print(
+                f"Parameter {self.param} Relative difference between analytical and numerical gradient is {diff}>1e-6: Check Failed"
+            )
+        else:
+            print(
+                f"Parameter {self.param} Relative difference between analytical and numerical gradient is {diff}<1e-6: Check Passed"
+            )
 
     def check(self):
         # Analytical gradient
@@ -135,4 +143,3 @@ def check(self):
         grad_plus = self.forward(model, self.param, self.eps)
         grad_minus = self.forward(model, self.param, -self.eps)
         self.check_(ana_grad, grad_plus, grad_minus)
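
For reference (not part of the diff): check_ implements a symmetric finite-difference test, numgrad = (L(param + eps) - L(param - eps)) / (2 * eps), and reports the normalized distance ||anagrad - numgrad|| / (||anagrad|| + ||numgrad||) against the new 1e-6 threshold. A self-contained sketch of the same computation on a single weight matrix, independent of the Gradient class:

import torch

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(10, 4, dtype=torch.double)
y = torch.randn(10, 3, dtype=torch.double)
w = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

def loss(w_):
    return ((torch.sigmoid(x.mm(w_)) - y) ** 2).mean()

# Analytical gradient from autograd
anagrad = torch.autograd.grad(loss(w), w)[0]

# Symmetric finite-difference gradient, one entry at a time
numgrad = torch.zeros_like(w)
with torch.no_grad():
    for idx in range(w.numel()):
        e = torch.zeros_like(w).view(-1)
        e[idx] = eps
        e = e.view_as(w)
        numgrad.view(-1)[idx] = (loss(w + e) - loss(w - e)) / (2.0 * eps)

diff = torch.norm(anagrad - numgrad) / (torch.norm(anagrad) + torch.norm(numgrad))
print(f"relative difference: {diff.item():.3e}")  # comfortably below 1e-6 in double precision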

28 changes: 18 additions & 10 deletions test_gradients.py
@@ -1,20 +1,28 @@
 import unittest
 from gradients.gradients import Gradient
-from example import *
+from examples import custom_module as cm
 import torch
 
 N, D_in, D_out = 10, 4, 3
 
 # Create random Tensors to hold inputs and outputs
 x = torch.randn(N, D_in)
 y = torch.randn(N, D_out)
-model = Model(D_in, D_out)
-mymodel = MyModel(D_in,D_out)
-criterion = torch.nn.MSELoss(reduction='mean')
-mycriterion = MSELoss.apply
+
+# Construct model by instantiating the class defined above
+mymodel = cm.MyModel(D_in, D_out)
+# criterion = cm.MSELoss.apply
+criterion = torch.nn.MSELoss(reduction="mean")
+
+# Test custom build model
+Gradient(mymodel, x, y, criterion, eps=1e-8)
 
 
 class TestGradient(unittest.TestCase):
 
     def testGradient(self):
-        self.assertRaises(Exception, Gradient(model,x,y,criterion,eps=1e-8))
-        self.assertRaises(Exception, Gradient(model,x,y,criterion,eps=1e-8))
-        self.assertRaises(Exception, Gradient(model,x,y,criterion,eps=1e-8))
-        self.assertRaises(Exception, Gradient(model,x,y,criterion,eps=1e-8))
+        self.assertRaises(Exception, Gradient(mymodel, x, y, criterion, eps=1e-8))
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(argv=["first-arg-is-ignored"], exit=False)
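
Not part of the commit: the commented-out criterion line above hints that the custom MSELoss can be swapped in for torch.nn.MSELoss. A hypothetical variant, assuming Gradient invokes the criterion as criterion(y_pred, y):

# Hypothetical: gradient-check MyModel against the custom MSELoss autograd Function instead.
mycriterion = cm.MSELoss.apply
Gradient(mymodel, x, y, mycriterion, eps=1e-8)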
