
Commit
Merge branch 'master' of github.com:datawhalechina/easy-rl
qiwang067 committed Sep 21, 2021
2 parents 1adddca + 1e60b68 commit 6ab9970
Showing 10 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion codes/QLearning/agent.py
@@ -5,7 +5,7 @@
 Email: johnjim0816@gmail.com
 Date: 2020-09-11 23:03:00
 LastEditor: John
-LastEditTime: 2021-09-15 13:18:37
+LastEditTime: 2021-09-19 23:05:45
 Description: use defaultdict to define Q table
 Environment:
 '''
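For reference, a minimal sketch of the defaultdict Q-table pattern that the description refers to (the action count and variable names here are assumptions for illustration; the full class is not part of this diff):

import numpy as np
from collections import defaultdict

n_actions = 4  # assumed; CliffWalking-v0 has 4 discrete actions
Q_table = defaultdict(lambda: np.zeros(n_actions))
# Any unseen state key lazily maps to a zero-initialized row of action
# values, so the state space never has to be enumerated up front.

This is why a tabular agent can index Q_table[state][action] directly without pre-allocating an array over all states.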
(7 output files changed: 5 binary files not shown, 2 files that cannot be displayed)
14 changes: 7 additions & 7 deletions codes/QLearning/task0_train.py
@@ -5,7 +5,7 @@
 Email: johnjim0816@gmail.com
 Date: 2020-09-11 23:03:00
 LastEditor: John
-LastEditTime: 2021-09-15 14:44:25
+LastEditTime: 2021-09-20 00:32:59
 Description:
 Environment:
 '''
@@ -31,13 +31,13 @@ def __init__(self):
         self.env = 'CliffWalking-v0' # environment name
         self.result_path = curr_path+"/outputs/" +self.env+'/'+curr_time+'/results/' # path for saving results
         self.model_path = curr_path+"/outputs/" +self.env+'/'+curr_time+'/models/' # path for saving models
-        self.train_eps = 200 # number of training episodes
+        self.train_eps = 400 # number of training episodes
         self.eval_eps = 30 # number of evaluation episodes
         self.gamma = 0.9 # discount factor for rewards
-        self.epsilon_start = 0.90 # initial epsilon for the e-greedy policy
+        self.epsilon_start = 0.99 # initial epsilon for the e-greedy policy
         self.epsilon_end = 0.01 # final epsilon for the e-greedy policy
-        self.epsilon_decay = 200 # decay rate of epsilon in the e-greedy policy
-        self.lr = 0.05 # learning rate
+        self.epsilon_decay = 300 # decay rate of epsilon in the e-greedy policy
+        self.lr = 0.1 # learning rate
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # detect GPU
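
These three epsilon parameters typically drive an exponential annealing schedule in this kind of e-greedy setup; a minimal sketch under that assumption (the actual formula lives in the agent code, which is not shown in this diff):

import math

def epsilon_by_sample(sample_count, start=0.99, end=0.01, decay=300):
    # Anneal epsilon from `start` toward `end`; `decay` sets roughly how
    # many samples the transition spans. Defaults mirror the new config.
    return end + (start - end) * math.exp(-sample_count / decay)

Raising epsilon_start to 0.99 and epsilon_decay to 300 keeps the agent exploring longer before it settles into exploitation, which pairs with the longer 400-episode training run.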


@@ -111,8 +111,8 @@ def eval(cfg,env,agent):
 plot_rewards_cn(rewards,ma_rewards,tag="train",env=cfg.env,algo = cfg.algo,path=cfg.result_path)

 # # Test
-# env,agent = env_agent_config(cfg,seed=10)
-# agent.load(path=cfg.model_path) # load the model
+env,agent = env_agent_config(cfg,seed=10)
+agent.load(path=cfg.model_path) # load the model
 rewards,ma_rewards = eval(cfg,env,agent)
 save_results(rewards,ma_rewards,tag='eval',path=cfg.result_path)
 plot_rewards_cn(rewards,ma_rewards,tag="eval",env=cfg.env,algo = cfg.algo,path=cfg.result_path)
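The gamma and lr values above enter the standard tabular Q-learning update; a minimal sketch with assumed variable names (the real update lives in agent.py, not this file):

def q_learning_update(Q_table, state, action, reward, next_state, done,
                      lr=0.1, gamma=0.9):
    # Move Q(s,a) toward the one-step TD target r + gamma * max_a' Q(s',a').
    # lr is the step size; terminal transitions bootstrap from 0.
    target = reward if done else reward + gamma * max(Q_table[next_state])
    Q_table[state][action] += lr * (target - Q_table[state][action])

A larger lr of 0.1 makes each update move Q(s,a) further toward its target, which can speed convergence on a small deterministic grid like CliffWalking.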
5 changes: 3 additions & 2 deletions codes/common/plot.py
@@ -5,7 +5,7 @@
 Email: johnjim0816@gmail.com
 Date: 2020-10-07 20:57:11
 LastEditor: John
-LastEditTime: 2021-09-15 14:56:15
+LastEditTime: 2021-09-19 23:00:36
 Description:
 Environment:
 '''
@@ -29,14 +29,15 @@ def plot_rewards_cn(rewards,ma_rewards,tag="train",env='CartPole-v0',algo = "DQN
     ''' plot with Chinese labels
     '''
     sns.set()
+    plt.figure()
     plt.title(u"{}环境下{}算法的学习曲线".format(env,algo),fontproperties=chinese_font())
     plt.xlabel(u'回合数',fontproperties=chinese_font())
     plt.plot(rewards)
     plt.plot(ma_rewards)
     plt.legend((u'奖励',u'滑动平均奖励',),loc="best",prop=chinese_font())
     if save:
         plt.savefig(path+f"{tag}_rewards_curve_cn")
-    plt.show()
+    # plt.show()

 def plot_losses(losses,algo = "DQN",save=True,path='./'):
     sns.set()
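For context, the ma_rewards series plotted above is conventionally an exponentially smoothed copy of rewards; a minimal sketch under that assumption (the 0.9 smoothing weight is an assumption, not read from this diff):

def smooth_rewards(rewards, weight=0.9):
    # Exponential moving average: each point blends the previous smoothed
    # value with the new reward, which steadies the plotted curve.
    ma_rewards = []
    for r in rewards:
        ma_rewards.append(ma_rewards[-1] * weight + r * (1 - weight) if ma_rewards else r)
    return ma_rewards

Adding plt.figure() before plotting gives each call its own figure, so the train and eval curves no longer draw on top of each other; commenting out plt.show() keeps batch runs from blocking on a GUI window while plt.savefig still writes the image.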
