Which package to use for ACA method? #2

Open
YuchiQiu opened this issue Jan 8, 2023 · 3 comments
YuchiQiu commented Jan 8, 2023

Very interesting work. I want to use the ACA method; should I use odesolver_mem instead?

Is odesolver the naive Neural ODE solver that uses automatic differentiation?

Do you have an example that uses it for training instead of odesolver? The two seem to have different usage.

Thanks.

@YuchiQiu YuchiQiu closed this as completed Jan 8, 2023
@YuchiQiu YuchiQiu reopened this Jan 8, 2023
juntang-zhuang (Owner) commented Jan 9, 2023

Sorry, I don't quite get your question. Here's an example that uses the ODE solver for integration (forward pass) and trains the parameters using gradients from the backward pass: https://github.com/juntang-zhuang/TorchDiffEqPack/blob/master/test_code/three_body_problem.py

For naive solvers, you can use odesolve:

from .odesolver import odesolve

For ACA, check out odesolve_adjoint, which uses less memory.
For MALI, check out odesolve_adjoint_sym12 (even less memory, roughly constant w.r.t. integration time):

from .odesolver_mem import odesolve_endtime, odesolve_adjoint, odesolve_adjoint_sym12

In my package you can use the three methods above interchangeably in most cases; they differ in their memory / accuracy tradeoffs.

Auto diff is a totally independent notion from the ODE solver. The ODE solver just gives the numerical solution to the ODE; auto diff back-propagates through an operation. In this special case the operation happens to be "solving an ODE", which is the same notion as back-prop through "matrix multiplication".
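To make the separation concrete, here is a minimal, dependency-free sketch. The `Var` class is a hypothetical toy reverse-mode autodiff written only for illustration (in PyTorch, autograd plays this role), and `euler_solve` is a plain fixed-step Euler integrator that knows nothing about gradients. The same `backward()` call differentiates a single multiplication and a whole ODE solve.

```python
class Var:
    """Scalar node in a toy reverse-mode autodiff graph (illustration only)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __radd__ = __add__
    __rmul__ = __mul__

    def backward(self):
        # Post-order traversal puts parents before children.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        # Walk children first so each node's gradient is complete before use.
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad


def euler_solve(f, y0, t0, t1, n_steps):
    """Plain fixed-step Euler integrator; contains no gradient logic at all."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = y + f(t, y) * h
        t = t + h
    return y


# Operation 1: a multiplication.
a, b = Var(3.0), Var(4.0)
product = a * b
product.backward()

# Operation 2: "solving an ODE", here dy/dt = -theta * y.
theta, y0 = Var(0.5), Var(2.0)
n_steps = 50
y1 = euler_solve(lambda t, y: theta * y * (-1.0), y0, 0.0, 1.0, n_steps)
y1.backward()

# Closed form for Euler on this ODE: y_N = y0 * (1 - theta * h) ** N.
h = 1.0 / n_steps
expected = y0.value * n_steps * (1 - theta.value * h) ** (n_steps - 1) * (-h)
```

The solver only produces the numerical solution; the gradient with respect to `theta` comes entirely from the autodiff machinery recording the arithmetic the solver happened to perform.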

YuchiQiu (Author) commented

Thanks for the explanation. This example uses the naive solver odesolve. How can I use ACA instead to learn from the data?

The example runs fine, but replacing odesolve with odesolve_adjoint for ACA results in errors. What else needs to be changed in the example? Here are the errors:

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/yuchiqiu/Dropbox/data-driven_special_issues/sample_code/three_body.py", line 276, in <module>
    out = odesolve_adjoint(func, initial_condition, options=options)#, time_points=t_list)
  File "/Users/yuchiqiu/Documents/venv/lib/python3.6/site-packages/TorchDiffEqPack/odesolver_mem/adjoint.py", line 194, in odesolve_adjoint
    zs = Checkpointing_Adjoint.apply(*_y0, func, options['t0'], options['t1'], flat_params, options)
  File "/Users/yuchiqiu/Documents/venv/lib/python3.6/site-packages/TorchDiffEqPack/odesolver_mem/adjoint.py", line 40, in forward
    solver = odesolve_endtime(func, z0, options, return_solver=True, regenerate_graph = False)
  File "/Users/yuchiqiu/Documents/venv/lib/python3.6/site-packages/TorchDiffEqPack/odesolver_mem/odesolver_endtime.py", line 26, in odesolve_endtime
    solver = Dopri5(func,y0=z0, **hyperparams, **kwargs)
TypeError: type object got multiple values for keyword argument 'regenerate_graph'

juntang-zhuang (Owner) commented Jan 12, 2023

You can remove 'regenerate_graph' from options when using odesolve_adjoint.
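The traceback points at a call of the form `Dopri5(func, y0=z0, **hyperparams, **kwargs)`, so a likely cause is that 'regenerate_graph' arrives twice: once from the user's options and once from the adjoint code itself. The sketch below reproduces that mechanism with hypothetical stand-ins (`Dopri5Stub` and `build_solver` are not the real TorchDiffEqPack classes or functions) and then applies the fix above by dropping the key from a copy of the options.

```python
class Dopri5Stub:
    """Hypothetical stand-in for a solver class; not the real Dopri5."""

    def __init__(self, func, y0, regenerate_graph=True, **other):
        self.regenerate_graph = regenerate_graph


def build_solver(func, y0, options, **kwargs):
    # Two ** unpackings in one call: a key present in both dicts
    # raises "got multiple values for keyword argument".
    return Dopri5Stub(func, y0=y0, **options, **kwargs)


# Options as they might be written for the naive solver (keys from the thread).
options = {"t0": 0.0, "t1": 1.0, "regenerate_graph": False}

# The adjoint path passes regenerate_graph itself, so the key collides.
try:
    build_solver(None, 0.0, options, regenerate_graph=False)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Fix: remove the key from a copy of options before the adjoint call.
adjoint_options = {k: v for k, v in options.items() if k != "regenerate_graph"}
solver = build_solver(None, 0.0, adjoint_options, regenerate_graph=False)
```

Filtering into a new dict leaves the original options untouched, so the same dict can still be passed to odesolve.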

This example might not be a good fit: odesolve_adjoint typically only supports extracting the value at time t1, so t_eval (a list of time points at which to extract the ODE value) is not supported. You can also check the image_classification folder, which uses the memory-efficient solvers. odesolve supports extraction at arbitrary time points.
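If values at several time points are needed from a solver that only returns the state at t1, one possible workaround (an assumption, not something suggested in this thread) is to chain end-time solves segment by segment. The sketch below shows the idea with a hypothetical stand-in, `solve_to_endtime`, implemented as fixed-step RK4 in plain Python; it is not the package's odesolve_adjoint.

```python
import math


def solve_to_endtime(f, y0, t0, t1, n_steps=1000):
    """Hypothetical end-time-only solver: returns y(t1) and nothing else."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        # Classical fourth-order Runge-Kutta step.
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return y


def solve_at_times(f, y0, time_points):
    """Chain end-time solves so the state is available at each time point."""
    outputs = [y0]
    y = y0
    for t_start, t_end in zip(time_points[:-1], time_points[1:]):
        # The end state of one segment is the initial state of the next.
        y = solve_to_endtime(f, y, t_start, t_end)
        outputs.append(y)
    return outputs


# Example: dy/dt = -y with y(0) = 1, whose exact solution is exp(-t).
times = [0.0, 0.5, 1.0, 2.0]
values = solve_at_times(lambda t, y: -y, 1.0, times)
exact = [math.exp(-t) for t in times]
```

With a real memory-efficient solver, each segment would be a separate call with its own t0 and t1, so the memory and accuracy behavior of the chained version may differ from a single solve.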

PS: the "naive" method is not the true naive method (the "true naive" method directly translates a NumPy integrator into PyTorch). It is much more accurate and memory efficient than the true naive solver, in which the stepsize search process is also back-propagated through.
