Skip to content

Commit

Permalink
Fix the fix of rtqichen#180
Browse files Browse the repository at this point in the history
Restores the backward pass of odeint_adjoint in event mode
  • Loading branch information
shivak committed Sep 20, 2021
1 parent ca85e83 commit 21c1e49
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchdiffeq/_impl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def backward(ctx, *grad_y):
# Does NOT backpropagate through the event time.
event_mode = ctx.event_mode
if event_mode:
t, y, event_t, *adjoint_params = ctx.saved_params
t, y, event_t, *adjoint_params = ctx.saved_tensors
_t = t
t = torch.cat([t[0].reshape(-1), event_t.reshape(-1)])
grad_y = grad_y[1]
Expand Down

0 comments on commit 21c1e49

Please sign in to comment.