Skip to content

Commit

Permalink
fix loading, adjust start_timesteps, add humanoid
Browse files Browse the repository at this point in the history
  • Loading branch information
sfujim committed Feb 13, 2020
1 parent f6cca9e commit ade6260
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 15 deletions.
6 changes: 5 additions & 1 deletion DDPG.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,17 @@ def train(self, replay_buffer, batch_size=64):
def save(self, filename):
	"""Write the networks and their optimizers to disk.

	Produces four files: <filename>_critic, <filename>_critic_optimizer,
	<filename>_actor, <filename>_actor_optimizer.
	"""
	# Same checkpoint layout for both components: weights first, then optimizer.
	for component in ("critic", "actor"):
		network = getattr(self, component)
		torch.save(network.state_dict(), filename + "_" + component)
		optimizer = getattr(self, component + "_optimizer")
		torch.save(optimizer.state_dict(), filename + "_" + component + "_optimizer")


def load(self, filename):
	"""Restore networks and optimizers saved by `save`, rebuilding targets.

	Reads <filename>_critic / _critic_optimizer / _actor / _actor_optimizer
	and resets each target network to an exact copy of its freshly loaded
	counterpart.
	"""
	for component in ("critic", "actor"):
		network = getattr(self, component)
		network.load_state_dict(torch.load(filename + "_" + component))
		optimizer = getattr(self, component + "_optimizer")
		optimizer.load_state_dict(torch.load(filename + "_" + component + "_optimizer"))
		# The target restarts in sync with the loaded network.
		setattr(self, component + "_target", copy.deepcopy(network))

6 changes: 5 additions & 1 deletion OurDDPG.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,17 @@ def train(self, replay_buffer, batch_size=100):
def save(self, filename):
	"""Checkpoint the agent: network weights plus optimizer state.

	Four files are written, suffixed _critic, _critic_optimizer,
	_actor and _actor_optimizer.
	"""
	components = ("critic", "actor")
	for name in components:
		# Weights, then the matching optimizer, mirroring `load`.
		torch.save(getattr(self, name).state_dict(), filename + "_" + name)
		opt = getattr(self, name + "_optimizer")
		torch.save(opt.state_dict(), filename + "_" + name + "_optimizer")


def load(self, filename):
	"""Load a checkpoint produced by `save` and re-derive target networks.

	Each target is replaced with a deep copy of the just-restored network
	so that training resumes with targets identical to the live networks.
	"""
	components = ("critic", "actor")
	for name in components:
		net = getattr(self, name)
		net.load_state_dict(torch.load(filename + "_" + name))
		opt = getattr(self, name + "_optimizer")
		opt.load_state_dict(torch.load(filename + "_" + name + "_optimizer"))
		setattr(self, name + "_target", copy.deepcopy(net))

5 changes: 5 additions & 0 deletions TD3.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,17 @@ def train(self, replay_buffer, batch_size=100):
def save(self, filename):
	"""Persist critic and actor (with their optimizers) to disk.

	File naming: <filename>_critic, <filename>_critic_optimizer,
	<filename>_actor, <filename>_actor_optimizer.
	"""
	pairs = (
		(self.critic, self.critic_optimizer, "_critic"),
		(self.actor, self.actor_optimizer, "_actor"),
	)
	for net, opt, tag in pairs:
		torch.save(net.state_dict(), filename + tag)
		torch.save(opt.state_dict(), filename + tag + "_optimizer")


def load(self, filename):
	"""Reload a checkpoint written by `save`; targets become deep copies.

	Restores critic then actor (weights and optimizer state), and resets
	critic_target / actor_target to exact copies of the loaded networks.
	"""
	pairs = (
		(self.critic, self.critic_optimizer, "_critic", "critic_target"),
		(self.actor, self.actor_optimizer, "_actor", "actor_target"),
	)
	for net, opt, tag, target_attr in pairs:
		net.load_state_dict(torch.load(filename + tag))
		opt.load_state_dict(torch.load(filename + tag + "_optimizer"))
		setattr(self, target_attr, copy.deepcopy(net))

2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def eval_policy(policy, env_name, seed, eval_episodes=10):
# Experiment configuration. NOTE(review): the scraped diff left both the old
# and new --start_timesteps registrations in place; registering the same
# option twice makes argparse raise ArgumentError, so only the updated
# (25k) registration is kept. Defaults are integer literals because
# `type=int` only converts command-line strings, not the default value —
# `default=25e3` would silently leave a float in args.start_timesteps.
parser.add_argument("--policy", default="TD3")                   # Policy name (TD3, DDPG or OurDDPG)
parser.add_argument("--env", default="HalfCheetah-v2")           # OpenAI gym environment name
parser.add_argument("--seed", default=0, type=int)               # Sets Gym, PyTorch and Numpy seeds
parser.add_argument("--start_timesteps", default=25_000, type=int)  # Time steps initial random policy is used
parser.add_argument("--eval_freq", default=5_000, type=int)      # How often (time steps) we evaluate
parser.add_argument("--max_timesteps", default=1_000_000, type=int)  # Max time steps to run environment
parser.add_argument("--expl_noise", default=0.1)                 # Std of Gaussian exploration noise
Expand Down
25 changes: 13 additions & 12 deletions run_experiments.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,28 @@ for ((i=0;i<10;i+=1))
do
# One training run per environment for the current seed $i.
# NOTE(review): the scraped diff interleaved the old (-v2 env plus explicit
# --start_timesteps) and new (-v3 env) argument lines, leaving invocations
# with duplicate --env flags and an orphaned --start_timesteps after the
# final line continuation. Reconstructed below as the committed "after"
# state: v3 environments, Humanoid added, and start_timesteps left to the
# default defined in main.py.
python main.py \
	--policy "TD3" \
	--env "HalfCheetah-v3" \
	--seed $i

python main.py \
	--policy "TD3" \
	--env "Hopper-v3" \
	--seed $i

python main.py \
	--policy "TD3" \
	--env "Walker2d-v3" \
	--seed $i

python main.py \
	--policy "TD3" \
	--env "Ant-v3" \
	--seed $i

python main.py \
	--policy "TD3" \
	--env "Humanoid-v3" \
	--seed $i

python main.py \
--policy "TD3" \
Expand Down

0 comments on commit ade6260

Please sign in to comment.