diff --git a/tests/problems.py b/tests/problems.py index 8a5546c9..98252032 100644 --- a/tests/problems.py +++ b/tests/problems.py @@ -60,7 +60,7 @@ def y_exact(self, t): DEVICES.append('cuda') FIXED_METHODS = ('euler', 'midpoint', 'heun2', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams') ADAMS_METHODS = ('explicit_adams', 'implicit_adams') -ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'dopri5', 'dopri8') +ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'tsit5', 'dopri5', 'dopri8') SCIPY_METHODS = ('scipy_solver',) METHODS = FIXED_METHODS + ADAPTIVE_METHODS + SCIPY_METHODS diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index 07f8666a..15146502 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -7,6 +7,7 @@ from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4 from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton from .dopri8 import Dopri8Solver +from .tsit5 import Tsit5Solver from .scipy_wrapper import ScipyWrapperODESolver from .misc import _check_inputs, _flat_to_shape from .interp import _interp_evaluate @@ -14,6 +15,7 @@ SOLVERS = { 'dopri8': Dopri8Solver, 'dopri5': Dopri5Solver, + 'tsit5': Tsit5Solver, 'bosh3': Bosh3Solver, 'fehlberg2': Fehlberg2, 'adaptive_heun': AdaptiveHeunSolver, diff --git a/torchdiffeq/_impl/tsit5.py b/torchdiffeq/_impl/tsit5.py new file mode 100644 index 00000000..4f4a2218 --- /dev/null +++ b/torchdiffeq/_impl/tsit5.py @@ -0,0 +1,82 @@ +import torch +from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver +# https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/lib/OrdinaryDiffEqTsit5/src/tsit_tableaus.jl +# https://github.com/patrick-kidger/diffrax/blob/14baa1edddcacf27c0483962b3c9cf2e86e6e5b6/diffrax/_solver/tsit5.py#L158 + +_TSITOURAS_TABLEAU = _ButcherTableau( + alpha=torch.tensor([ + 161 / 1000, + 327 / 1000, + 9 / 10, + .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604, + 1, + 1 + ], dtype=torch.float64), + beta=[ + torch.tensor([161 / 1000], dtype=torch.float64), + torch.tensor([ + -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2, + .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569 + ], dtype=torch.float64), + torch.tensor([ + 2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289, + -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366, + 4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077, + ], dtype=torch.float64), + torch.tensor([ + 5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791, + -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589, + 7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753, + -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1 + ], dtype=torch.float64), + torch.tensor([ + 5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748, + -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452, + 8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986, + -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1, + -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1 + ], dtype=torch.float64), + torch.tensor([ + .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1, + 1 / 100, + .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509, + 1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331, + -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677, + 2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841 + ], dtype=torch.float64), + ], + c_sol=torch.tensor([ + .9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1, + .9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2, + .4877705284247615707855642599631228241516691959761363774365216240304071651579571959813, + 1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761, + -2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702, + 1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089, + 1 / 66 + ], dtype=torch.float64), + c_error=torch.tensor([ + -1.780011052225771443378550607539534775944678804333659557637450799792588061629796e-03, + -8.164344596567469032236360633546862401862537590159047610940604670770447527463931e-04, + 7.880878010261996010314727672526304238628733777103128603258129604952959142646516e-03, + -1.44711007173262907537165147972635116720922712343167677619514233896760819649515e-01, + 5.823571654525552250199376106520421794260781239567387797673045438803694038950012e-01, + -4.580821059291869466616365188325542974428047279788398179474684434732070620889539e-01, + 1 / 66 + ], dtype=torch.float64), +) + +x = 1 / 2 +TSIT_C_MID = torch.tensor([ + -1.0530884977290216*x*(x-1.329989018975412)*(x*x-1.4364028541716351*x+0.7139816917074209), + 0.1017*x*x*(x*x-2.1966568338249754*x+1.2949852507374631), + 2.490627285651252793*x*x*(x*x-2.38535645472061657*x+1.57803468208092486), + -16.54810288924490272*(x-1.21712927295533244)*(x-0.61620406037800089)*x*x, + 47.37952196281928122*(x-1.203071208372362603)*(x-0.658047292653547382)*x*x, + -34.87065786149660974*(x-1.2)*(x-2/3)*x*x, + 2.5*(x-1)*(x-0.6)*x*x +], dtype=torch.float64) + +class Tsit5Solver(RKAdaptiveStepsizeODESolver): + order = 5 + tableau = _TSITOURAS_TABLEAU + mid = TSIT_C_MID