Skip to content

Commit

Permalink
change
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxiaochang committed Oct 31, 2020
1 parent 430982b commit ab9fcd9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test_virtual/dqn-overtaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _on_step(self) -> bool:


# create log dir
log_dir = "DqnOvertaking_tmp/"
log_dir = "test_results/DqnOvertaking/" + TIMESTAMP
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
Expand All @@ -79,11 +79,11 @@ def _on_step(self) -> bool:
# Instantiate the agent
# model = DQN(EgoAttentionNetwork, env, learning_rate=1e-3, prioritized_replay=True, verbose=1)
model = DQN("MlpPolicy", env, learning_rate=1e-3, prioritized_replay=True, verbose=1,
tensorboard_log="./DQN_overtaking_tensorboard/")
tensorboard_log="./test_results/DQN_overtaking_tensorboard/" + TIMESTAMP)
# create the callback: check every 1000 steps
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
# Train the agent
time_steps = 1e5
time_steps = 1000
model.learn(total_timesteps=int(time_steps), callback=callback)

results_plotter.plot_results([log_dir], time_steps, results_plotter.X_TIMESTEPS, "DQN OvertakingEnv")
Expand Down

0 comments on commit ab9fcd9

Please sign in to comment.