Skip to content

Commit

Permalink
Updates to cutlass code to add offset nn interp
Browse files Browse the repository at this point in the history
  • Loading branch information
wi-re committed Aug 26, 2024
1 parent 70a0a84 commit 0ddec05
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 69 deletions.
114 changes: 66 additions & 48 deletions src/BasisConvolution/detail/basis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from .util import cpow, getDistancesRel
from .util import cpow, getDistancesRel, getDistancesRel_offset
import numpy as np

@torch.jit.script
Expand All @@ -18,55 +18,73 @@ def evalRBFSeries(n : int, x : torch.Tensor, which : str = 'linear', epsilon : f
res = torch.zeros_like(r)
if not adjustSpacing and not normalized:
if which == 'linear': res = torch.clamp(1. - r / epsilon,0,1)
if which == 'gaussian': res = torch.exp(-(epsilon * r)**2)
if which == 'multiquadric': res = torch.sqrt(1. + (epsilon * r) **2)
if which == 'inverse_quadric': res = 1. / ( 1 + (epsilon * r) **2)
if which == 'inverse_multiquadric': res = 1. / torch.sqrt(1. + (epsilon * r) **2)
if which == 'polyharmonic': res = torch.pow(r, k) if k % 2 == 1 else torch.pow(r,k-1) * torch.log(torch.pow(r,r))
if which == 'bump': res = torch.where(r < 1./epsilon, torch.exp(-1./(1- (epsilon * r)**2)), torch.zeros_like(r))
if which == 'cubic_spline': res = cpow(1-r/(epsilon * 1.),3) - 4. * cpow(1/2-r/(epsilon * 1.),3)
if which == 'quartic_spline': res = cpow(1-r/(epsilon * 1.),4) - 5 * cpow(3/5-r/(epsilon * 1.),4) + 10 * cpow(1/5-r/(epsilon * 1.),4)
if which == 'quintic_spline': res = cpow(1-r/(epsilon * 1.),5) - 6 * cpow(2/3-r/(epsilon * 1.),5) + 15 * cpow(1/3-r/(epsilon * 1.),5)
if which == 'wendland2': res = cpow(1 - r/(epsilon * 1.), 4) * (1 + 4 * r/(epsilon * 1.))
if which == 'wendland4': res = cpow(1 - r/(epsilon * 1.), 6) * (1 + 6 * r/(epsilon * 1.) + 35/3 * (r/(epsilon * 1.))**2)
if which == 'wendland6': res = cpow(1 - r/(epsilon * 1.), 8) * (1 + 8 * r/(epsilon * 1.) + 25 * (r/(epsilon * 1.)) **2 + 32 * (r * (epsilon * 1.))**3)
if which == 'poly6': res = cpow(1 - (r/epsilon)**2, 3)
if which == 'spiky': res = cpow(1 - r/epsilon, 3)
if which == 'square': res = torch.where(torch.logical_and(rRel > -0.5 * epsilon, rRel <= 0.5 * epsilon), torch.ones_like(r), torch.zeros_like(r))
if adjustSpacing and not normalized:
elif which == 'gaussian': res = torch.exp(-(epsilon * r)**2)
elif which == 'multiquadric': res = torch.sqrt(1. + (epsilon * r) **2)
elif which == 'inverse_quadric': res = 1. / ( 1 + (epsilon * r) **2)
elif which == 'inverse_multiquadric': res = 1. / torch.sqrt(1. + (epsilon * r) **2)
elif which == 'polyharmonic': res = torch.pow(r, k) if k % 2 == 1 else torch.pow(r,k-1) * torch.log(torch.pow(r,r))
elif which == 'bump': res = torch.where(r < 1./epsilon, torch.exp(-1./(1- (epsilon * r)**2)), torch.zeros_like(r))
elif which == 'cubic_spline': res = cpow(1-r/(epsilon * 1.),3) - 4. * cpow(1/2-r/(epsilon * 1.),3)
elif which == 'quartic_spline': res = cpow(1-r/(epsilon * 1.),4) - 5 * cpow(3/5-r/(epsilon * 1.),4) + 10 * cpow(1/5-r/(epsilon * 1.),4)
elif which == 'quintic_spline': res = cpow(1-r/(epsilon * 1.),5) - 6 * cpow(2/3-r/(epsilon * 1.),5) + 15 * cpow(1/3-r/(epsilon * 1.),5)
elif which == 'wendland2': res = cpow(1 - r/(epsilon * 1.), 4) * (1 + 4 * r/(epsilon * 1.))
elif which == 'wendland4': res = cpow(1 - r/(epsilon * 1.), 6) * (1 + 6 * r/(epsilon * 1.) + 35/3 * (r/(epsilon * 1.))**2)
elif which == 'wendland6': res = cpow(1 - r/(epsilon * 1.), 8) * (1 + 8 * r/(epsilon * 1.) + 25 * (r/(epsilon * 1.)) **2 + 32 * (r * (epsilon * 1.))**3)
elif which == 'poly6': res = cpow(1 - (r/epsilon)**2, 3)
elif which == 'spiky': res = cpow(1 - r/epsilon, 3)
elif which == 'square': res = torch.where(torch.logical_and(rRel > -0.5 * epsilon, rRel <= 0.5 * epsilon), torch.ones_like(r), torch.zeros_like(r))
elif which == 'square_offset':
rRel = getDistancesRel_offset(n, x, periodic)
res = torch.where(torch.logical_and(rRel > -0.5 * 1, rRel <= 0.5 * 1), torch.ones_like(r), torch.zeros_like(r))
else:
raise ValueError('Unknown basis function')
elif adjustSpacing and not normalized:
if which == 'linear': res = torch.clamp(1. - r / epsilon,0,1)
if which == 'gaussian': res = torch.exp(-(epsilon * r)**2)
if which == 'multiquadric': res = torch.sqrt(1. + (epsilon * r) **2)
if which == 'inverse_quadric': res = 1. / ( 1 + (epsilon * r) **2)
if which == 'inverse_multiquadric': res = 1. / torch.sqrt(1. + (epsilon * r) **2)
if which == 'polyharmonic': res = torch.pow(r, k) if k % 2 == 1 else torch.pow(r,k-1) * torch.log(torch.pow(r,r))
if which == 'bump': res = torch.where(r < 1./epsilon, torch.exp(-1./(1- (epsilon * r)**2)), torch.zeros_like(r))
if which == 'cubic_spline': res = cpow(1-r/(epsilon * 1.732051),3) - 4. * cpow(1/2-r/(epsilon * 1.732051),3)
if which == 'quartic_spline': res = cpow(1-r/(epsilon * 1.936492),4) - 5 * cpow(3/5-r/(epsilon * 1.936492),4) + 10 * cpow(1/5-r/(epsilon * 1.732051),4)
if which == 'quintic_spline': res = cpow(1-r/(epsilon * 2.121321),5) - 6 * cpow(2/3-r/(epsilon * 2.121321),5) + 15 * cpow(1/3-r/(epsilon * 2.121321),5)
if which == 'wendland2': res = cpow(1 - r/(epsilon * 1.620185), 4) * (1 + 4 * r/(epsilon * 1.620185))
if which == 'wendland4': res = cpow(1 - r/(epsilon * 1.936492), 6) * (1 + 6 * r/(epsilon * 1.936492) + 35/3 * (r/(epsilon * 1.936492))**2)
if which == 'wendland6': res = cpow(1 - r/(epsilon * 2.207940), 8) * (1 + 8 * r/(epsilon * 2.207940) + 25 * (r/(epsilon * 2.207940)) **2 + 32 * (r * (epsilon * 2.207940))**3)
if which == 'poly6': res = cpow(1 - (r/epsilon)**2, 3)
if which == 'spiky': res = cpow(1 - r/epsilon, 3)
if which == 'square': res = torch.where(torch.logical_and(rRel > -0.5 * epsilon, rRel <= 0.5 * epsilon), torch.ones_like(r), torch.zeros_like(r))
if not adjustSpacing and normalized:
elif which == 'gaussian': res = torch.exp(-(epsilon * r)**2)
elif which == 'multiquadric': res = torch.sqrt(1. + (epsilon * r) **2)
elif which == 'inverse_quadric': res = 1. / ( 1 + (epsilon * r) **2)
elif which == 'inverse_multiquadric': res = 1. / torch.sqrt(1. + (epsilon * r) **2)
elif which == 'polyharmonic': res = torch.pow(r, k) if k % 2 == 1 else torch.pow(r,k-1) * torch.log(torch.pow(r,r))
elif which == 'bump': res = torch.where(r < 1./epsilon, torch.exp(-1./(1- (epsilon * r)**2)), torch.zeros_like(r))
elif which == 'cubic_spline': res = cpow(1-r/(epsilon * 1.732051),3) - 4. * cpow(1/2-r/(epsilon * 1.732051),3)
elif which == 'quartic_spline': res = cpow(1-r/(epsilon * 1.936492),4) - 5 * cpow(3/5-r/(epsilon * 1.936492),4) + 10 * cpow(1/5-r/(epsilon * 1.732051),4)
elif which == 'quintic_spline': res = cpow(1-r/(epsilon * 2.121321),5) - 6 * cpow(2/3-r/(epsilon * 2.121321),5) + 15 * cpow(1/3-r/(epsilon * 2.121321),5)
elif which == 'wendland2': res = cpow(1 - r/(epsilon * 1.620185), 4) * (1 + 4 * r/(epsilon * 1.620185))
elif which == 'wendland4': res = cpow(1 - r/(epsilon * 1.936492), 6) * (1 + 6 * r/(epsilon * 1.936492) + 35/3 * (r/(epsilon * 1.936492))**2)
elif which == 'wendland6': res = cpow(1 - r/(epsilon * 2.207940), 8) * (1 + 8 * r/(epsilon * 2.207940) + 25 * (r/(epsilon * 2.207940)) **2 + 32 * (r * (epsilon * 2.207940))**3)
elif which == 'poly6': res = cpow(1 - (r/epsilon)**2, 3)
elif which == 'spiky': res = cpow(1 - r/epsilon, 3)
elif which == 'square': res = torch.where(torch.logical_and(rRel > -0.5 * epsilon, rRel <= 0.5 * epsilon), torch.ones_like(r), torch.zeros_like(r))
elif which == 'square_offset':
rRel = getDistancesRel_offset(n, x, periodic)
res = torch.where(torch.logical_and(rRel > -0.5 * 1, rRel <= 0.5 * 1), torch.ones_like(r), torch.zeros_like(r))
else:
raise ValueError('Unknown basis function')
elif not adjustSpacing and normalized:
if which == 'linear': res = torch.clamp(1. - r / 1,0,1)
if which == 'gaussian': res = torch.exp(-(0.9919394235466537 * r)**2)
if which == 'multiquadric': res = torch.sqrt(1. + (1 * r) **2)
if which == 'inverse_quadric': res = 1. / ( 1 + (1.1480214948705423 * r) **2)
if which == 'inverse_multiquadric': res = 1. / torch.sqrt(1. + (1.6382510991695163 * r) **2)
if which == 'polyharmonic': res = torch.pow(r, k) if k % 2 == 1 else torch.pow(r,k-1) * torch.log(torch.pow(r,r))
if which == 'bump': res = torch.where(r < 1./0.38739618954567656, torch.exp(-1./(1- (0.38739618954567656 * r)**2)), torch.zeros_like(r))
if which == 'cubic_spline': res = cpow(1-r/(epsilon * 2.009770395701026),3) - 4. * cpow(1/2-r/(epsilon * 2.009770395701026),3)
if which == 'quartic_spline': res = cpow(1-r/(epsilon * 2.4318514899853443),4) - 5 * cpow(3/5-r/(epsilon * 2.4318514899853443),4) + 10 * cpow(1/5-r/(epsilon * 2.4318514899853443),4)
if which == 'quintic_spline': res = cpow(1-r/(epsilon * 2.8903273082559844),5) - 6 * cpow(2/3-r/(epsilon * 2.8903273082559844),5) + 15 * cpow(1/3-r/(epsilon * 2.8903273082559844),5)
if which == 'wendland2': res = cpow(1 - r/(epsilon * 3.6238397655105032), 4) * (1 + 4 * r/(epsilon * 3.6238397655105032))
if which == 'wendland4': res = cpow(1 - r/(epsilon * 3.7338788470933073), 6) * (1 + 6 * r/(epsilon * 3.7338788470933073) + 35/3 * (r/(epsilon * 3.7338788470933073))**2)
if which == 'wendland6': res = cpow(1 - r/(epsilon * 1.3856863702979971), 8) * (1 + 8 * r/(epsilon * 1.3856863702979971) + 25 * (r/(epsilon * 1.3856863702979971)) **2 + 32 * (r * (epsilon * 1.3856863702979971))**3)
if which == 'poly6': res = cpow(1 - (r/ 2.6936980947728384)**2, 3)
if which == 'spiky': res = cpow(1 - r/3, 3)
if which == 'square': res = torch.where(torch.logical_and(rRel > -0.5 * 1, rRel <= 0.5 * 1), torch.ones_like(r), torch.zeros_like(r))
elif which == 'gaussian': res = torch.exp(-(0.9919394235466537 * r)**2)
elif which == 'multiquadric': res = torch.sqrt(1. + (1 * r) **2)
elif which == 'inverse_quadric': res = 1. / ( 1 + (1.1480214948705423 * r) **2)
elif which == 'inverse_multiquadric': res = 1. / torch.sqrt(1. + (1.6382510991695163 * r) **2)
elif which == 'polyharmonic': res = torch.pow(r, k) if k % 2 == 1 else torch.pow(r,k-1) * torch.log(torch.pow(r,r))
elif which == 'bump': res = torch.where(r < 1./0.38739618954567656, torch.exp(-1./(1- (0.38739618954567656 * r)**2)), torch.zeros_like(r))
elif which == 'cubic_spline': res = cpow(1-r/(epsilon * 2.009770395701026),3) - 4. * cpow(1/2-r/(epsilon * 2.009770395701026),3)
elif which == 'quartic_spline': res = cpow(1-r/(epsilon * 2.4318514899853443),4) - 5 * cpow(3/5-r/(epsilon * 2.4318514899853443),4) + 10 * cpow(1/5-r/(epsilon * 2.4318514899853443),4)
elif which == 'quintic_spline': res = cpow(1-r/(epsilon * 2.8903273082559844),5) - 6 * cpow(2/3-r/(epsilon * 2.8903273082559844),5) + 15 * cpow(1/3-r/(epsilon * 2.8903273082559844),5)
elif which == 'wendland2': res = cpow(1 - r/(epsilon * 3.6238397655105032), 4) * (1 + 4 * r/(epsilon * 3.6238397655105032))
elif which == 'wendland4': res = cpow(1 - r/(epsilon * 3.7338788470933073), 6) * (1 + 6 * r/(epsilon * 3.7338788470933073) + 35/3 * (r/(epsilon * 3.7338788470933073))**2)
elif which == 'wendland6': res = cpow(1 - r/(epsilon * 1.3856863702979971), 8) * (1 + 8 * r/(epsilon * 1.3856863702979971) + 25 * (r/(epsilon * 1.3856863702979971)) **2 + 32 * (r * (epsilon * 1.3856863702979971))**3)
elif which == 'poly6': res = cpow(1 - (r/ 2.6936980947728384)**2, 3)
elif which == 'spiky': res = cpow(1 - r/3, 3)
elif which == 'square': res = torch.where(torch.logical_and(rRel > -0.5 * 1, rRel <= 0.5 * 1), torch.ones_like(r), torch.zeros_like(r))
elif which == 'square_offset':
rRel = getDistancesRel_offset(n, x, periodic)
res = torch.where(torch.logical_and(rRel > -0.5 * 1, rRel <= 0.5 * 1), torch.ones_like(r), torch.zeros_like(r))
# print('square_offset')
else:
raise ValueError('Unknown basis function')
else:
raise ValueError('Normalized and adjusted spacing is not supported')

if normalized:
res = res / torch.sum(res, dim = 0)
Expand Down
Loading

0 comments on commit 0ddec05

Please sign in to comment.