Skip to content

Commit

Permalink
upd examples and readme
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Nov 16, 2018
1 parent e467dbb commit a9b9a01
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pip install -e .
## Examples
Examples are placed in the [`examples`](./examples) directory.

We encourage those who are interested in using this library to take a look at `examples/ode_demo.py` for understanding how to use `torchdiffeq` to fit a simple spiral ODE.
We encourage those who are interested in using this library to take a look at [`examples/ode_demo.py`](./examples/ode_demo.py) for understanding how to use `torchdiffeq` to fit a simple spiral ODE.

<p align="center">
<img align="middle" src="./assets/ode_demo.gif" alt="ODE Demo" width="500" height="250" />
Expand Down Expand Up @@ -74,3 +74,15 @@ Fixed-step:

### References
[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." *Advances in Neural Processing Information Systems.* 2018.

---

If you found this library useful in your research, please consider citing
```
@article{chen2018neural,
title={Neural Ordinary Differential Equations},
author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
journal={Advances in Neural Information Processing Systems},
year={2018}
}
```
6 changes: 3 additions & 3 deletions examples/ode_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.optim as optim

parser = argparse.ArgumentParser('ODE demo')
parser.add_argument('--method', type=str, default='dopri5')
parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5')
parser.add_argument('--data_size', type=int, default=1000)
parser.add_argument('--batch_time', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=20)
Expand Down Expand Up @@ -38,7 +38,7 @@ def forward(self, t, y):


with torch.no_grad():
true_y = odeint(Lambda(), true_y0, t)
true_y = odeint(Lambda(), true_y0, t, method='dopri5')


def get_batch():
Expand Down Expand Up @@ -152,7 +152,7 @@ def update(self, val):
ii = 0

func = ODEFunc()
optimizer = optim.Adam(func.parameters(), lr=1e-3)
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
end = time.time()

time_meter = RunningAverageMeter(0.97)
Expand Down

0 comments on commit a9b9a01

Please sign in to comment.