Skip to content

Commit

Permalink
Fixes error when using pytorch1.7+ with GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Jun 2, 2021
1 parent 203999e commit 25fd9f2
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions torchdiffeq/_impl/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,6 @@ class Perturb(Enum):


class _PerturbFunc(torch.nn.Module):
_inf = torch.tensor(math.inf)
_neginf = torch.tensor(-math.inf)

def __init__(self, base_func):
super(_PerturbFunc, self).__init__()
Expand All @@ -181,10 +179,10 @@ def forward(self, t, y, *, perturb=Perturb.NONE):
t = t.to(y.dtype)
if perturb is Perturb.NEXT:
# Replace with next smallest representable value.
t = _nextafter(t, self._inf)
t = _nextafter(t, t + 1)
elif perturb is Perturb.PREV:
# Replace with prev largest representable value.
t = _nextafter(t, self._neginf)
t = _nextafter(t, t - 1)
else:
# Do nothing.
pass
Expand Down

0 comments on commit 25fd9f2

Please sign in to comment.