Background

In JAX, transformations can be customized for user-defined functions, for example by customizing the JVP rule. This can be achieved in two ways:
- wrap your custom function as a primitive and register JVP rules for the primitive, or
- use the @jax.custom_jvp decorator.
This is very useful: when you need a function with a customized gradient, for example giving a discrete function a continuous gradient relaxation, it can be done through this registration mechanism, while the upper-level constructs, e.g. jax.grad(f), stay the same no matter what the JVP of f is.
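As a minimal sketch of the decorator route (the hard_round function and its straight-through relaxation below are illustrative choices, not something from this issue):

```python
import jax
import jax.numpy as jnp

# A discrete function: hard rounding has a zero gradient almost everywhere.
@jax.custom_jvp
def hard_round(x):
  return jnp.round(x)

# Custom JVP: keep the discrete forward value but pass the tangent through
# unchanged, i.e. a continuous (straight-through) gradient relaxation.
@hard_round.defjvp
def hard_round_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return hard_round(x), x_dot

# Upper-level constructs are untouched: jax.grad(f) just works.
print(jax.grad(lambda x: hard_round(x) ** 2)(1.7))  # 4.0, via the relaxed JVP
```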
Density Functional Theory
To express DFT elegantly, we would like the high-level program to look like the following:
```python
import jax
import jax.numpy as jnp
import jax_xc

import autofd.operators as o
import dft.energy as e

def wave_ansatz(param, r):
  ...

def energy(crystal, param, occ):
  psi = o.partial(wave_ansatz, args=(param,), argnums=(0,))

  def psi_to_rho(psi):
    rho = o.compose(
      lambda v: jnp.sum(occ * jnp.real(jnp.conj(v) * v)), psi
    )
    return rho

  def energy_functional(psi):
    # use psi to build rho, in a functionally differentiable way.
    rho = psi_to_rho(psi)
    # compute the total energy using the functionals
    etot = (
      e.kinetic(psi)
      + e.hartree(rho)
      + e.external(rho, crystal)
      + jax_xc.energy.lda_x(rho)
    )
    return etot

  return energy_functional(psi)

# gradient descent to optimize crystal, param, occ etc.
jax.grad(energy)(crystal, param, occ)
```
The problem with the elegant code.
One difficulty in achieving this elegance is efficiency. If we wrote ugly code, we could optimize all of the above energy calculations by hand: derive each term and implement the derived formula directly. For example, the kinetic energy in a plane-wave basis is simply a summation over the norms of the G and k grids, because applying the Laplace operator to plane waves only adds a constant factor, and since psi is normalized there is no need to compute <psi|psi>, which is always 1. However, none of these details are available to autofd, resulting in an inefficient implementation when we follow the above code for DFT.
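As a concrete instance of such a hand-derived shortcut, in a plane-wave basis the kinetic energy is (my notation, with $c_{n,k}(G)$ the plane-wave coefficients and $f_{n,k}$ the occupations):

$$T = \frac{1}{2} \sum_{n,k,G} f_{n,k}\,|c_{n,k}(G)|^2\,|k+G|^2,$$

i.e. a weighted sum over the norms of the $k$ and $G$ grids, with no Laplacian applied and no $\langle\psi|\psi\rangle$ integral evaluated.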
Why do we want the elegance?
In physics we often write the math with great simplicity; e.g., to get the energy levels, we simply solve the following eigenvalue problem.
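(Presumably this is the Kohn-Sham-style eigenvalue problem; the notation below is my reconstruction, not the author's:

$$\hat{H}[\rho]\,\psi_i = \epsilon_i\,\psi_i, \qquad \hat{H}[\rho] = -\tfrac{1}{2}\nabla^2 + v_{\mathrm{ext}} + \frac{\delta E_{\mathrm{Hxc}}[\rho]}{\delta \rho},$$

where the $\epsilon_i$ are the energy levels and the last term is the functional derivative referred to below.)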
It takes many steps to turn this non-implementable math into implementable math: the squiggly symbol that represents the functional derivative is hand-derived, the integrals are discretized, and Fourier transforms are used.
Then why don't we just implement the derived formulas rather than following the functional form? It is not just for elegance, but also for extensibility. For example, one thing we often do is linearize the energy at the current value of rho and use it as an effective potential for computing the energy bands of crystals. That already gives me a headache if we had to re-derive the math for constructing the Fock matrix and implement it with FFTs. However, if we could support the above syntax, we would have an easier way around it.
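On a discretized real-space grid, this linearization is just an ordinary gradient. A toy sketch in plain JAX (E_of_rho below is a stand-in energy, not one of the functionals above):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for an energy functional of a density sampled on a grid.
def E_of_rho(rho):
  return jnp.sum(rho ** (4.0 / 3.0))

rho = jnp.ones((16,))
# Linearizing the energy at the current rho: the gradient of the energy with
# respect to the grid values plays the role of an effective potential v_eff.
E0, v_eff = jax.value_and_grad(E_of_rho)(rho)
```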
How to keep both elegance and efficiency?

A compiler is the way! Let's write elegant code and enjoy the cleanness of the various tasks in DFT, while relying on a compilation process to turn the code into a high-performance implementation. The optimizations in such a compiler take lots of rules here and there; therefore, we want autofd to support custom rules for the cases where the user knows a more efficient implementation.
Custom rules for any operators
In JAX, we can customize the rules for the transpose operator, the JVP operator, etc. A straightforward extension is to support customized rules for all operators.
Again with kinetic energy as an example, say we have built a primitive for computing the kinetic energy of wave functions.
```python
from jax import core

# A custom primitive for kinetic energy
kinetic_p = core.Primitive("kinetic")

def kinetic(psi):
  return kinetic_p.bind(psi)

def kinetic_impl(psi):
  # general implementation
  ...
```
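We can customize it for a specific wave function. As a sketch of what that could look like (the def_impl registration uses the standard JAX primitive API; the plane-wave rule and its argument layout are hypothetical, and how such a specialized rule would actually be attached to the primitive is exactly the open question of this issue):

```python
import jax.numpy as jnp

# Register the general implementation as the primitive's default rule.
kinetic_p.def_impl(kinetic_impl)

# Hypothetical specialized rule: when psi is a plane-wave expansion, the
# kinetic energy reduces to a weighted sum over |k + G|^2, so no Laplacian
# ever needs to be applied.
def kinetic_planewave_impl(coeff, k_plus_g, occ):
  # coeff: (n_bands, n_g) plane-wave coefficients
  # k_plus_g: (n_g, 3) reciprocal-space vectors
  # occ: (n_bands,) occupation numbers
  g2 = jnp.sum(k_plus_g ** 2, axis=-1)  # |k + G|^2 for every plane wave
  return 0.5 * jnp.sum(occ[:, None] * jnp.abs(coeff) ** 2 * g2[None, :])
```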
Mixing different rules

One difficult question is how we can register many different custom rules and how they interfere with each other. Can we retain the custom kinetic rule when we first apply the JVP rule to psi? I need to study this further before I have an idea. (To be continued)