Commit bc98e62: Camera-ready cleanup
sfujim committed Jun 9, 2018
1 parent 00ae5cd
Showing 44 changed files with 145 additions and 10 deletions.
4 changes: 2 additions & 2 deletions DDPG.py
@@ -94,8 +94,8 @@ def train(self, replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.0

  # Q target = reward + discount * Q(next_state, pi(next_state))
  target_Q = self.critic_target(next_state, self.actor_target(next_state))
- target_Q.volatile = False
  target_Q = reward + (done * discount * target_Q)
+ target_Q.volatile = False

  # Get current Q estimate
  current_Q = self.critic(state, action)
@@ -120,7 +120,7 @@ def train(self, replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.0
  for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
      target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

- for param, target_param, in zip(self.actor.parameters(), self.actor_target.parameters()):
+ for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
      target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
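Note (not part of the commit): the target_Q.volatile = False idiom is PyTorch 0.3-era; volatile was removed in PyTorch 0.4. A sketch of how the same target computation would read inside train() under PyTorch >= 0.4, assuming next_state, reward, and done are plain tensors rather than Variables:

# Sketch only, assuming PyTorch >= 0.4: compute the bootstrapped target without
# tracking gradients, which is what the volatile flag achieves in this commit.
with torch.no_grad():
    target_Q = self.critic_target(next_state, self.actor_target(next_state))
    target_Q = reward + done * discount * target_Q
# target_Q carries no graph, so critic_loss.backward() only touches self.critic.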
133 changes: 133 additions & 0 deletions OurDDPG.py
@@ -0,0 +1,133 @@
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

import utils


# Re-tuned version of Deep Deterministic Policy Gradients (DDPG)
# Paper: https://arxiv.org/abs/1509.02971


def var(tensor, volatile=False):
    if torch.cuda.is_available():
        return Variable(tensor, volatile=volatile).cuda()
    else:
        return Variable(tensor, volatile=volatile)


class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()

        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action


    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.max_action * F.tanh(self.l3(x))
        return x


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()

        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)


    def forward(self, x, u):
        x = F.relu(self.l1(torch.cat([x, u], 1)))
        x = F.relu(self.l2(x))
        x = self.l3(x)
        return x


class DDPG(object):
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())

        if torch.cuda.is_available():
            self.actor = self.actor.cuda()
            self.actor_target = self.actor_target.cuda()
            self.critic = self.critic.cuda()
            self.critic_target = self.critic_target.cuda()

        self.criterion = nn.MSELoss()
        self.state_dim = state_dim


    def select_action(self, state):
        state = var(torch.FloatTensor(state.reshape(-1, self.state_dim)), volatile=True)
        return self.actor(state).cpu().data.numpy().flatten()


    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005):

        for it in range(iterations):

            # Sample replay buffer
            x, y, u, r, d = replay_buffer.sample(batch_size)
            state = var(torch.FloatTensor(x))
            action = var(torch.FloatTensor(u))
            next_state = var(torch.FloatTensor(y), volatile=True)
            done = var(torch.FloatTensor(1 - d))
            reward = var(torch.FloatTensor(r))

            # Q target = reward + discount * Q(next_state, pi(next_state))
            target_Q = self.critic_target(next_state, self.actor_target(next_state))
            target_Q = reward + (done * discount * target_Q)
            target_Q.volatile = False

            # Get current Q estimate
            current_Q = self.critic(state, action)

            # Compute critic loss
            critic_loss = self.criterion(current_Q, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Compute actor loss
            actor_loss = -self.critic(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


    def save(self, filename, directory):
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))


    def load(self, filename, directory):
        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename)))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename)))
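Not part of the commit: a minimal sketch of how this class is driven. main.py below does the equivalent with argument parsing, exploration noise, and periodic evaluation; the ReplayBuffer.add tuple order here is an assumption based on the (x, y, u, r, d) unpacking in train().

# Minimal usage sketch (assumptions: gym is installed, utils.ReplayBuffer.add takes a
# (state, next_state, action, reward, done) tuple matching the sample() unpacking).
import gym
import numpy as np

import utils
import OurDDPG

env = gym.make("HalfCheetah-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = int(env.action_space.high[0])

policy = OurDDPG.DDPG(state_dim, action_dim, max_action)
replay_buffer = utils.ReplayBuffer()

obs = env.reset()
for t in range(10000):
    # Gaussian exploration noise, clipped to the valid action range
    action = (policy.select_action(np.array(obs))
              + np.random.normal(0, 0.1, size=action_dim)).clip(-max_action, max_action)
    new_obs, reward, done, _ = env.step(action)
    replay_buffer.add((obs, new_obs, action, reward, float(done)))  # assumed tuple order
    obs = env.reset() if done else new_obs

policy.train(replay_buffer, iterations=200)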
6 changes: 3 additions & 3 deletions TD3.py
@@ -108,11 +108,11 @@ def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.
  next_action = self.actor_target(next_state) + var(torch.FloatTensor(noise))
  next_action = next_action.clamp(-self.max_action, self.max_action)

- # Q target = reward + discount * min(Qi(next_state, pi(next_state)))
+ # Q target = reward + discount * min_i(Qi(next_state, pi(next_state)))
  target_Q1, target_Q2 = self.critic_target(next_state, next_action)
- target_Q = torch.min(torch.cat([target_Q1, target_Q2], 1), 1)[0].view(-1, 1)
- target_Q.volatile = False
+ target_Q = torch.min(target_Q1, target_Q2)
  target_Q = reward + (done * discount * target_Q)
+ target_Q.volatile = False

  # Get current Q estimates
  current_Q1, current_Q2 = self.critic(state, action)
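The torch.min change above is purely cosmetic: the element-wise minimum of the two [batch, 1] critic outputs equals the old cat/min-over-dim expression. A quick check, not part of the commit:

# Quick check that the two expressions agree on [batch, 1] shaped critic outputs.
import torch

q1, q2 = torch.randn(4, 1), torch.randn(4, 1)
old = torch.min(torch.cat([q1, q2], 1), 1)[0].view(-1, 1)
new = torch.min(q1, q2)
assert torch.equal(old, new)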
Binary file modified learning_curves/Ant/TD3_Ant-v1_0.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_1.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_2.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_3.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_4.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_5.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_6.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_7.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_8.npy
Binary file modified learning_curves/Ant/TD3_Ant-v1_9.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_0.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_1.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_2.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_3.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_4.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_5.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_6.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_7.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_8.npy
Binary file modified learning_curves/HalfCheetah/TD3_HalfCheetah-v1_9.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_0.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_1.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_2.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_3.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_4.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_5.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_6.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_7.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_8.npy
Binary file modified learning_curves/Hopper/TD3_Hopper-v1_9.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_0.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_1.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_2.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_3.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_4.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_5.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_6.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_7.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_8.npy
Binary file modified learning_curves/Walker/TD3_Walker2d-v1_9.npy
12 changes: 7 additions & 5 deletions main.py
@@ -5,8 +5,9 @@
  import os

  import utils
- import DDPG
  import TD3
+ import OurDDPG
+ import DDPG


  # Runs policy for X episodes and returns average reward
@@ -16,7 +17,7 @@ def evaluate_policy(policy, eval_episodes=10):
  obs = env.reset()
  done = False
  while not done:
-     action = policy.select_action(np.array(obs)).clip(env.action_space.low, env.action_space.high)
+     action = policy.select_action(np.array(obs))
      obs, reward, done, _ = env.step(action)
      avg_reward += reward
@@ -69,8 +70,9 @@ def evaluate_policy(policy, eval_episodes=10):
  max_action = int(env.action_space.high[0])

  # Initialize policy
- if args.policy_name == "DDPG": policy = DDPG.DDPG(state_dim, action_dim, max_action)
- elif args.policy_name == "TD3": policy = TD3.TD3(state_dim, action_dim, max_action)
+ if args.policy_name == "TD3": policy = TD3.TD3(state_dim, action_dim, max_action)
+ elif args.policy_name == "OurDDPG": policy = OurDDPG.DDPG(state_dim, action_dim, max_action)
+ elif args.policy_name == "DDPG": policy = DDPG.DDPG(state_dim, action_dim, max_action)

  replay_buffer = utils.ReplayBuffer()
@@ -98,7 +100,7 @@ def evaluate_policy(policy, eval_episodes=10):
  timesteps_since_eval %= args.eval_freq
  evaluations.append(evaluate_policy(policy))

- if args.save_models: policy.save("%s" % (file_name), directory="./pytorch_models")
+ if args.save_models: policy.save(file_name, directory="./pytorch_models")
  np.save("./results/%s" % (file_name), evaluations)

  # Reset environment
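Dropping the .clip(...) in evaluate_policy is safe because select_action is already bounded: the actor ends in max_action * tanh. A sanity check, not part of the commit, written against the PyTorch 0.3-style code above:

# Sanity check: actor outputs are bounded by max_action * tanh, so the extra
# clip removed from evaluate_policy was a no-op.
import numpy as np
import OurDDPG

policy = OurDDPG.DDPG(state_dim=3, action_dim=1, max_action=2.0)
action = policy.select_action(np.zeros(3))
assert np.all(np.abs(action) <= 2.0)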
