Skip to content

Commit

Permalink
Merge pull request rtqichen#126 from coldlbq/lbq
Browse files Browse the repository at this point in the history
turn the ode_demo.py to run on gpu(previously it can not run on gpu)
  • Loading branch information
rtqichen committed Oct 16, 2020
2 parents cd847cf + 3ef5ac3 commit 68b7ab0
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions examples/ode_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
true_y0 = torch.tensor([[2., 0.]]).to(device)
t = torch.linspace(0., 25., args.data_size).to(device)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)


class Lambda(nn.Module):
Expand All @@ -46,7 +46,7 @@ def get_batch():
batch_y0 = true_y[s] # (M, D)
batch_t = t[:args.batch_time] # (T)
batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D)
return batch_y0, batch_t, batch_y
return batch_y0.to(device), batch_t.to(device), batch_y.to(device)


def makedirs(dirname):
Expand All @@ -72,18 +72,18 @@ def visualize(true_y, pred_y, odefunc, itr):
ax_traj.set_title('Trajectories')
ax_traj.set_xlabel('t')
ax_traj.set_ylabel('x,y')
ax_traj.plot(t.numpy(), true_y.numpy()[:, 0, 0], t.numpy(), true_y.numpy()[:, 0, 1], 'g-')
ax_traj.plot(t.numpy(), pred_y.numpy()[:, 0, 0], '--', t.numpy(), pred_y.numpy()[:, 0, 1], 'b--')
ax_traj.set_xlim(t.min(), t.max())
ax_traj.plot(t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 0], t.cpu().numpy(), true_y.cpu().numpy()[:, 0, 1], 'g-')
ax_traj.plot(t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 0], '--', t.cpu().numpy(), pred_y.cpu().numpy()[:, 0, 1], 'b--')
ax_traj.set_xlim(t.cpu().min(), t.cpu().max())
ax_traj.set_ylim(-2, 2)
ax_traj.legend()

ax_phase.cla()
ax_phase.set_title('Phase Portrait')
ax_phase.set_xlabel('x')
ax_phase.set_ylabel('y')
ax_phase.plot(true_y.numpy()[:, 0, 0], true_y.numpy()[:, 0, 1], 'g-')
ax_phase.plot(pred_y.numpy()[:, 0, 0], pred_y.numpy()[:, 0, 1], 'b--')
ax_phase.plot(true_y.cpu().numpy()[:, 0, 0], true_y.cpu().numpy()[:, 0, 1], 'g-')
ax_phase.plot(pred_y.cpu().numpy()[:, 0, 0], pred_y.cpu().numpy()[:, 0, 1], 'b--')
ax_phase.set_xlim(-2, 2)
ax_phase.set_ylim(-2, 2)

Expand All @@ -93,7 +93,7 @@ def visualize(true_y, pred_y, odefunc, itr):
ax_vecfield.set_ylabel('y')

y, x = np.mgrid[-2:2:21j, -2:2:21j]
dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2))).cpu().detach().numpy()
dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()
mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
dydt = (dydt / mag)
dydt = dydt.reshape(21, 21, 2)
Expand Down Expand Up @@ -151,17 +151,19 @@ def update(self, val):

ii = 0

func = ODEFunc()
func = ODEFunc().to(device)

optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
end = time.time()

time_meter = RunningAverageMeter(0.97)

loss_meter = RunningAverageMeter(0.97)

for itr in range(1, args.niters + 1):
optimizer.zero_grad()
batch_y0, batch_t, batch_y = get_batch()
pred_y = odeint(func, batch_y0, batch_t)
pred_y = odeint(func, batch_y0, batch_t).to(device)
loss = torch.mean(torch.abs(pred_y - batch_y))
loss.backward()
optimizer.step()
Expand Down

0 comments on commit 68b7ab0

Please sign in to comment.