Skip to content

Commit

Permalink
Add unit tests for callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
slishak committed Jul 27, 2022
1 parent 72e3f8c commit ef2540b
Showing 1 changed file with 119 additions and 1 deletion.
120 changes: 119 additions & 1 deletion tests/odeint_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torchdiffeq

from problems import (construct_problem, PROBLEMS, DTYPES, DEVICES, METHODS, ADAPTIVE_METHODS, FIXED_METHODS)
from problems import (construct_problem, PROBLEMS, DTYPES, DEVICES, METHODS, ADAPTIVE_METHODS, FIXED_METHODS, SCIPY_METHODS)


def rel_error(true, estimate):
Expand Down Expand Up @@ -239,5 +239,123 @@ def test_min_max_step(self):
self.assertGreater(f.nfe, 100)


class _NeuralF(torch.nn.Module):
def __init__(self, width, oscillate):
super(_NeuralF, self).__init__()
self.linears = torch.nn.Sequential(torch.nn.Linear(2, width),
torch.nn.Tanh(),
torch.nn.Linear(width, 2),
torch.nn.Tanh())
self.nfe = 0
self.oscillate = oscillate

def forward(self, t, x):
self.nfe += 1
out = self.linears(x)
if self.oscillate:
out = out * t.mul(20).sin()
return out


class TestCallbacks(unittest.TestCase):
def test_wrong_callback(self):
x0 = torch.tensor([1.0, 2.0])
t = torch.tensor([0., 1.0])

for method in FIXED_METHODS:
for callback_name in ('callback_accept_step', 'callback_reject_step'):
with self.subTest(method=method):
f = _NeuralF(width=10, oscillate=False)
setattr(f, callback_name, lambda t0, y0, dt: None)
with self.assertWarns(Warning):
torchdiffeq.odeint(f, x0, t, method=method)

for method in SCIPY_METHODS:
for callback_name in ('callback_step', 'callback_accept_step', 'callback_reject_step'):
with self.subTest(method=method):
f = _NeuralF(width=10, oscillate=False)
setattr(f, callback_name, lambda t0, y0, dt: None)
with self.assertWarns(Warning):
torchdiffeq.odeint(f, x0, t, method=method)

def test_steps(self):
for forward, adjoint in ((False, True), (True, False), (True, True)):
for method in FIXED_METHODS + ADAPTIVE_METHODS:
if method == 'dopri8': # using torch.float32
continue
with self.subTest(forward=forward, adjoint=adjoint, method=method):

f = _NeuralF(width=10, oscillate=False)

if forward:
forward_counter = 0
forward_accept_counter = 0
forward_reject_counter = 0

def callback_step(t0, y0, dt):
nonlocal forward_counter
forward_counter += 1

def callback_accept_step(t0, y0, dt):
nonlocal forward_accept_counter
forward_accept_counter += 1

def callback_reject_step(t0, y0, dt):
nonlocal forward_reject_counter
forward_reject_counter += 1

f.callback_step = callback_step
if method in ADAPTIVE_METHODS:
f.callback_accept_step = callback_accept_step
f.callback_reject_step = callback_reject_step

if adjoint:
adjoint_counter = 0
adjoint_accept_counter = 0
adjoint_reject_counter = 0

def callback_step_adjoint(t0, y0, dt):
nonlocal adjoint_counter
adjoint_counter += 1

def callback_accept_step_adjoint(t0, y0, dt):
nonlocal adjoint_accept_counter
adjoint_accept_counter += 1

def callback_reject_step_adjoint(t0, y0, dt):
nonlocal adjoint_reject_counter
adjoint_reject_counter += 1

f.callback_step_adjoint = callback_step_adjoint
if method in ADAPTIVE_METHODS:
f.callback_accept_step_adjoint = callback_accept_step_adjoint
f.callback_reject_step_adjoint = callback_reject_step_adjoint

x0 = torch.tensor([1.0, 2.0])
t = torch.tensor([0., 1.0])

if method in FIXED_METHODS:
kwargs = dict(options=dict(step_size=0.1))
elif method == 'implicit_adams':
kwargs = dict(rtol=1e-3, atol=1e-4)
else:
kwargs = {}
xs = torchdiffeq.odeint_adjoint(f, x0, t, method=method, **kwargs)

if forward:
if method in FIXED_METHODS:
self.assertEqual(forward_counter, 10)
if method in ADAPTIVE_METHODS:
self.assertGreater(forward_counter, 0)
self.assertEqual(forward_accept_counter + forward_reject_counter, forward_counter)
if adjoint:
xs.sum().backward()
if method in FIXED_METHODS:
self.assertEqual(adjoint_counter, 10)
if method in ADAPTIVE_METHODS:
self.assertGreater(adjoint_counter, 0)
self.assertEqual(adjoint_accept_counter + adjoint_reject_counter, adjoint_counter)


if __name__ == '__main__':
unittest.main()

0 comments on commit ef2540b

Please sign in to comment.