How to add Tsit5 solver? #194

Open
Saltsmart opened this issue Jan 23, 2022 · 2 comments · May be fixed by #261
@Saltsmart

Hello! I'm trying to use a Neural ODE for a forecasting problem (extrapolating further in time). Tsit5 is a relatively recent explicit Runge-Kutta method (see paper) and is available as a solver in the Julia package DifferentialEquations.jl. A forecasting notebook shows it is popular for this kind of task.

I want to implement a Tsit5 solver in torchdiffeq. It should look something like this (the coefficients are still the unchanged Dormand-Prince ones from dopri5.py):

```python
import torch
# Private torchdiffeq module (the one dopri5.py imports from); the path may change between versions.
from torchdiffeq._impl.rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver

# Placeholder: these are still the Dormand-Prince (dopri5) coefficients.
_TSITOURAS_TABLEAU = _ButcherTableau(
    alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64),
    beta=[
        torch.tensor([1 / 5], dtype=torch.float64),
        torch.tensor([3 / 40, 9 / 40], dtype=torch.float64),
        torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=torch.float64),
        torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=torch.float64),
        torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=torch.float64),
        torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=torch.float64),
    ],
    c_sol=torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=torch.float64),
    c_error=torch.tensor([
        35 / 384 - 1951 / 21600,
        0,
        500 / 1113 - 22642 / 50085,
        125 / 192 - 451 / 720,
        -2187 / 6784 - -12231 / 42400,
        11 / 84 - 649 / 6300,
        -1. / 60.,
    ], dtype=torch.float64),
)

_TS_C_MID = torch.tensor([
    6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2,
    187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2
], dtype=torch.float64)


class Tsit5Solver(RKAdaptiveStepsizeODESolver):
    order = 5
    tableau = _TSITOURAS_TABLEAU
    mid = _TS_C_MID
```

It's easy enough to figure out alpha, beta and c_sol, but what do c_error and _TS_C_MID mean?
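As far as I can tell from torchdiffeq's rk_common.py, c_error holds the difference b - b_hat between the fifth-order solution weights (c_sol) and the embedded fourth-order weights, so the local error estimate for a step is dt * Σ c_error_i * k_i, which is then compared against rtol/atol to accept or reject the step. The sketch below reimplements one such step in plain Python on a scalar ODE to show how the four tableau fields fit together. rk_step and the DP_* constants are hypothetical names, and this is not torchdiffeq's actual code.

```python
import math
from typing import Callable, List, Sequence, Tuple


def rk_step(
    f: Callable[[float, float], float],
    t0: float,
    y0: float,
    dt: float,
    alpha: Sequence[float],
    beta: Sequence[Sequence[float]],
    c_sol: Sequence[float],
    c_error: Sequence[float],
) -> Tuple[float, float, List[float]]:
    """One explicit embedded Runge-Kutta step (sketch); returns (y1, err, k)."""
    k = [f(t0, y0)]  # k1; FSAL pairs would reuse the previous step's last k here
    for alpha_i, beta_i in zip(alpha, beta):
        # Stage input: y0 + dt * sum_j a_ij * k_j over the stages computed so far.
        yi = y0 + dt * sum(a * kj for a, kj in zip(beta_i, k))
        k.append(f(t0 + alpha_i * dt, yi))
    # Fifth-order solution from the weights b (c_sol).
    y1 = y0 + dt * sum(b * kj for b, kj in zip(c_sol, k))
    # Local error estimate from the error weights c_error = b - b_hat.
    err = dt * sum(e * kj for e, kj in zip(c_error, k))
    return y1, err, k


# Dormand-Prince values from the question, as plain floats.
DP_ALPHA = [1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
DP_BETA = [
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
DP_C_SOL = [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]
DP_C_ERROR = [
    35 / 384 - 1951 / 21600,
    0,
    500 / 1113 - 22642 / 50085,
    125 / 192 - 451 / 720,
    -2187 / 6784 - -12231 / 42400,
    11 / 84 - 649 / 6300,
    -1 / 60,
]

# Usage: dy/dt = y with y(0) = 1, so the exact value at t = 0.1 is exp(0.1).
y1, err, k = rk_step(lambda t, y: y, 0.0, 1.0, 0.1, DP_ALPHA, DP_BETA, DP_C_SOL, DP_C_ERROR)
print(f"|y1 - exp(0.1)| = {abs(y1 - math.exp(0.1)):.2e}, error estimate = {err:.2e}")
```

Because the last beta row equals c_sol without its trailing 0, the seventh stage is evaluated exactly at (t0 + dt, y1); that is the FSAL property that lets the next step reuse k[-1] as its k1.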

@Saltsmart (author)

I'm not familiar with Julia, so I found a Nim implementation in numericalnim instead. It uses the same coefficients as the paper mentioned above.

```nim
proc TSIT54_step[T](f: ODEProc[T], t: float, y, FSAL: T, dt: float,
                     options: ODEoptions, ctx: NumContext[T]): (T, T, float, float) =
    ## Take a single timestep using TSIT54. Only for internal use.
    const
        c2 = 0.161
        c3 = 0.327
        c4 = 0.9
        c5 = 0.9800255409045097
        c6 = 1.0
        c7 = 1.0
        a21 = 0.161
        a31 = -0.008480655492356989
        a32 = 0.335480655492357
        a41 = 2.8971530571054935
        a42 = -6.359448489975075
        a43 = 4.3622954328695815
        a51 = 5.325864828439257
        a52 = -11.748883564062828
        a53 = 7.4955393428898365
        a54 = -0.09249506636175525
        a61 = 5.86145544294642
        a62 = -12.92096931784711
        a63 = 8.159367898576159
        a64 = -0.071584973281401
        a65 = -0.028269050394068383
        a71 = 0.09646076681806523
        a72 = 0.01
        a73 = 0.4798896504144996
        a74 = 1.379008574103742
        a75 = -3.290069515436081
        a76 = 2.324710524099774
        # Fifth order
        b1 = a71
        b2 = a72
        b3 = a73
        b4 = a74
        b5 = a75
        b6 = a76
        # Error weights: 5th-order weights minus embedded 4th-order weights (used directly in error_y)
        bHat1 = -0.001780011052226
        bHat2 = -0.000816434459657
        bHat3 = 0.007880878010262
        bHat4 = -0.144711007173263
        bHat5 = 0.582357165452555
        bHat6 = -0.458082105929187
        bHat7 = 1.0/66.0
    let absTol = options.absTol
    let relTol = options.relTol
    let dtMax = options.dtMax
    let dtMin = options.dtMin
    var k1, k2, k3, k4, k5, k6, k7: T
    var yNew, yLow: T
    var error: float
    var limitCounter = 0
    var dt = dt
    commonAdaptiveMethodCode(yNew, error_y, order=5):
        k1 = FSAL
        k2 = f(t + dt*c2, y + dt * (a21 * k1), ctx)
        k3 = f(t + dt*c3, y + dt * (a31 * k1 + a32 * k2), ctx)
        k4 = f(t + dt*c4, y + dt * (a41 * k1 + a42 * k2 + a43 * k3), ctx)
        k5 = f(t + dt*c5, y + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4), ctx)
        k6 = f(t + dt*c6, y + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5), ctx)
        k7 = f(t + dt*c7, y + dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6), ctx)

        yNew = y + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6)
        let error_y = dt * (bHat1 * k1 + bHat2 * k2 + bHat3 * k3 + bHat4 * k4 + bHat5 * k5 + bHat6 * k6 + bHat7 * k7)
        # error = calcError(y, yLow)
    result = (yNew, k7, dt, error)
```
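Reading the numericalnim constants against torchdiffeq's fields: c2..c7 become alpha, the a_ij rows become beta, and since a7j = b_j the method is FSAL, so c_sol is the last beta row plus a trailing 0 (the same layout as dopri5). The bHat values are multiplied directly into error_y and sum to roughly zero, so they appear to be the error weights b - b_hat and map onto c_error, not onto a set of fourth-order weights. A minimal sketch under that reading, kept as plain Python lists so it runs without torch (the TSIT5_* names are hypothetical); in torchdiffeq each list would be wrapped in torch.tensor(..., dtype=torch.float64) inside a _ButcherTableau as in the question:

```python
# Tsit5 coefficients copied from the numericalnim snippet, rearranged into
# the alpha / beta / c_sol / c_error layout of torchdiffeq's _ButcherTableau.

TSIT5_ALPHA = [0.161, 0.327, 0.9, 0.9800255409045097, 1.0, 1.0]  # c2..c7

TSIT5_BETA = [  # row i holds a_(i+2),1 .. a_(i+2),(i+1)
    [0.161],
    [-0.008480655492356989, 0.335480655492357],
    [2.8971530571054935, -6.359448489975075, 4.3622954328695815],
    [5.325864828439257, -11.748883564062828, 7.4955393428898365,
     -0.09249506636175525],
    [5.86145544294642, -12.92096931784711, 8.159367898576159,
     -0.071584973281401, -0.028269050394068383],
    [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742,
     -3.290069515436081, 2.324710524099774],
]

# Fifth-order weights b1..b6 equal the last beta row (FSAL); k7 gets weight 0.
TSIT5_C_SOL = TSIT5_BETA[-1] + [0.0]

# numericalnim's bHat1..bHat7, used as error weights (b - b_hat).
TSIT5_C_ERROR = [
    -0.001780011052226,
    -0.000816434459657,
    0.007880878010262,
    -0.144711007173263,
    0.582357165452555,
    -0.458082105929187,
    1.0 / 66.0,
]
```

Each beta row should sum to its alpha entry, c_sol should sum to 1 and c_error to roughly 0; those are cheap checks that nothing was mistyped while copying.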

@Saltsmart (author)

What I already know (torchdiffeq field name = usual Butcher-tableau symbol):
alpha = c
beta = a
c_sol = b
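Of the remaining fields, mid (_TS_C_MID in the question) seems to be used only for dense output: torchdiffeq's _interp_fit forms y_mid = y0 + dt * Σ mid_i * k_i as an estimate of y(t0 + dt/2) and fits a fourth-order polynomial through y0, y_mid, y1 and the endpoint derivatives. The Dormand-Prince DPS_C_MID values are consistent with this, since they sum to 1/2. So for Tsit5, mid should be the continuous-extension weights b_i(θ) evaluated at θ = 1/2. The sketch below computes them from interpolation polynomials b_i(θ) = r_i1 θ + r_i2 θ² + r_i3 θ³ + r_i4 θ⁴. The r values are an assumption, recalled from OrdinaryDiffEq.jl's Tsit5 implementation rather than taken from this issue, so they should be checked against the paper; tsit5_dense_weights and TSIT5_* are hypothetical names.

```python
from typing import List

# ASSUMPTION: dense-output coefficients recalled from OrdinaryDiffEq.jl's
# Tsit5 interpolation (not from this issue); verify against the paper.
# Row i gives b_i(theta) = r_i1*theta + r_i2*theta**2 + r_i3*theta**3 + r_i4*theta**4.
TSIT5_R = [
    [1.0, -2.763706197274826, 2.9132554618219126, -1.0530884977290216],
    [0.0, 0.13169999999999998, -0.2234, 0.1017],
    [0.0, 3.9302962368947516, -5.941033872131505, 2.490627285651253],
    [0.0, -12.411077166933676, 30.33818863028232, -16.548102889244902],
    [0.0, 37.50931341651104, -88.1789048947664, 47.37952196281928],
    [0.0, -27.896526289197286, 65.09189467479366, -34.87065786149661],
    [0.0, 1.5, -4.0, 2.5],
]


def tsit5_dense_weights(theta: float) -> List[float]:
    """Weights w_i with y(t0 + theta*dt) ~= y0 + dt * sum_i w_i * k_i."""
    return [
        sum(r * theta ** (j + 1) for j, r in enumerate(row)) for row in TSIT5_R
    ]


# Candidate Tsit5 replacement for _TS_C_MID: one weight per stage k1..k7.
TSIT5_C_MID = tsit5_dense_weights(0.5)
print(TSIT5_C_MID)
```

Two cheap checks catch most transcription errors: the weights must sum to θ (consistency), and tsit5_dense_weights(1.0) must reproduce c_sol, i.e. the b values from the numericalnim snippet with 0 for k7.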

psv4 linked a pull request on Feb 3, 2025 that will close this issue.