Skip to content

Commit

Permalink
implement Heun order 3
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Aug 30, 2023
1 parent 7265eb7 commit 84e220a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tests/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def y_exact(self, t):
DEVICES = ['cpu']
if torch.cuda.is_available():
DEVICES.append('cuda')
FIXED_METHODS = ('euler', 'midpoint', 'rk4', 'explicit_adams', 'implicit_adams')
FIXED_METHODS = ('euler', 'midpoint', 'heun3', 'rk4', 'explicit_adams', 'implicit_adams')
ADAMS_METHODS = ('explicit_adams', 'implicit_adams')
ADAPTIVE_METHODS = ('adaptive_heun', 'fehlberg2', 'bosh3', 'dopri5', 'dopri8')
SCIPY_METHODS = ('scipy_solver',)
Expand Down
18 changes: 17 additions & 1 deletion torchdiffeq/_impl/fixed_grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .solvers import FixedGridODESolver
from .rk_common import rk4_alt_step_func
from .rk_common import rk4_alt_step_func, rk3_step_func
from .misc import Perturb


Expand Down Expand Up @@ -27,3 +27,19 @@ class RK4(FixedGridODESolver):
def _step_func(self, func, t0, dt, t1, y0):
f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
return rk4_alt_step_func(func, t0, dt, t1, y0, f0=f0, perturb=self.perturb), f0


class Heun3(FixedGridODESolver):
order = 3

def _step_func(self, func, t0, dt, t1, y0):
f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)

butcher_tableu = [
[0.0, 0.0, 0.0, 0.0],
[1/3, 1/3, 0.0, 0.0],
[2/3, 0.0, 2/3, 0.0],
[0.0, 1/4, 0.0, 3/4],
]

return rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=butcher_tableu, f0=f0, perturb=self.perturb), f0
3 changes: 2 additions & 1 deletion torchdiffeq/_impl/odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .bosh3 import Bosh3Solver
from .adaptive_heun import AdaptiveHeunSolver
from .fehlberg2 import Fehlberg2
from .fixed_grid import Euler, Midpoint, RK4
from .fixed_grid import Euler, Midpoint, Heun3, RK4
from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from .dopri8 import Dopri8Solver
from .scipy_wrapper import ScipyWrapperODESolver
Expand All @@ -18,6 +18,7 @@
'adaptive_heun': AdaptiveHeunSolver,
'euler': Euler,
'midpoint': Midpoint,
'heun3': Heun3,
'rk4': RK4,
'explicit_adams': AdamsBashforth,
'implicit_adams': AdamsBashforthMoulton,
Expand Down
21 changes: 21 additions & 0 deletions torchdiffeq/_impl/rk_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ def rk4_alt_step_func(func, t0, dt, t1, y0, f0=None, perturb=False):
return (k1 + 3 * (k2 + k3) + k4) * dt * 0.125


def rk3_step_func(func, t0, dt, t1, y0, butcher_tableu=None, f0=None, perturb=False):
"""butcher_tableu should be of the form
[
[0 , 0 , 0 , 0],
[c_2, a_{21}, 0 , 0],
[c_3, a_{31}, a_{32}, 0],
[0 , b_1 , b_2 , b_3],
]
https://en.wikipedia.org/wiki/List_of_Runge-Kutta_methods
"""
k1 = f0
if k1 is None:
k1 = func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE)

k2 = func(t0 + dt * butcher_tableu[1][0], y0 + dt * k1 * butcher_tableu[1][1])
k3 = func(t0 + dt * butcher_tableu[2][0], y0 + dt * (k1 * butcher_tableu[2][1] + k2 * butcher_tableu[2][2]))
return dt * (k1 * butcher_tableu[3][1] + k2 * butcher_tableu[3][2] + k3 * butcher_tableu[3][3])


class RKAdaptiveStepsizeODESolver(AdaptiveStepsizeEventODESolver):
order: int
tableau: _ButcherTableau
Expand Down

0 comments on commit 84e220a

Please sign in to comment.