Commit 41fb561 (parent: 75df999), forked from datawhalechina/easy-rl. Showing 75 changed files with 1,248 additions and 918 deletions.
import random
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class CNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CNN, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.features = nn.Sequential(
            nn.Conv2d(input_dim[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size(), 512),
            nn.ReLU(),
            nn.Linear(512, self.output_dim)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def feature_size(self):
        # run a dummy forward pass to infer the flattened size of the conv output
        with torch.no_grad():
            return self.features(torch.zeros(1, *self.input_dim)).view(1, -1).size(1)

    def act(self, state, epsilon):
        ''' Epsilon-greedy action selection directly on the network.
        '''
        if random.random() > epsilon:
            with torch.no_grad():
                state = torch.tensor(np.float32(state)).unsqueeze(0)
                q_value = self.forward(state)
                action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.output_dim)
        return action

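# A minimal usage sketch (illustrative only, not part of the original commit): assuming four
# stacked 84x84 grayscale frames and six discrete actions, the network maps a batch of
# observations to one Q value per action.
#
#   net = CNN(input_dim=(4, 84, 84), output_dim=6)
#   q = net(torch.zeros(32, 4, 84, 84))   # q.shape == torch.Size([32, 6])
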
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity  # capacity of the replay buffer
        self.buffer = []  # storage for transitions
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        ''' The buffer behaves like a circular queue: once the capacity is exceeded,
        the oldest stored transition is overwritten.
        '''
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)  # draw a random mini-batch of transitions
        state, action, reward, next_state, done = zip(*batch)  # unzip into states, actions, etc.
        return state, action, reward, next_state, done

    def __len__(self):
        ''' Return the number of transitions currently stored.
        '''
        return len(self.buffer)

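# Illustrative usage (not part of the original file): transitions are pushed one step at a
# time and sampled uniformly once the buffer holds at least one batch.
#
#   buffer = ReplayBuffer(capacity=10000)
#   buffer.push(state, action, reward, next_state, done)
#   if len(buffer) >= 32:
#       states, actions, rewards, next_states, dones = buffer.sample(32)
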
class DQN:
    def __init__(self, n_states, n_actions, cfg):

        self.n_actions = n_actions  # total number of actions
        self.device = cfg.device  # device to run on, cpu or gpu
        self.gamma = cfg.gamma  # discount factor for rewards
        # parameters of the epsilon-greedy policy
        self.frame_idx = 0  # frame counter used to decay epsilon
        # epsilon decays exponentially from epsilon_start towards epsilon_end
        self.epsilon = lambda frame_idx: cfg.epsilon_end + \
            (cfg.epsilon_start - cfg.epsilon_end) * \
            math.exp(-1. * frame_idx / cfg.epsilon_decay)
        self.batch_size = cfg.batch_size
        self.policy_net = CNN(n_states, n_actions).to(self.device)
        self.target_net = CNN(n_states, n_actions).to(self.device)
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):  # copy parameters into the target network
            target_param.data.copy_(param.data)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=cfg.lr)  # optimizer
        self.memory = ReplayBuffer(cfg.memory_capacity)  # experience replay buffer

    def choose_action(self, state):
        ''' Choose an action with the epsilon-greedy policy.
        '''
        self.frame_idx += 1
        if random.random() > self.epsilon(self.frame_idx):
            with torch.no_grad():
                state = torch.tensor([state], device=self.device, dtype=torch.float32)
                q_values = self.policy_net(state)
                action = q_values.max(1)[1].item()  # pick the action with the largest Q value
        else:
            action = random.randrange(self.n_actions)
        return action

    def update(self):
        if len(self.memory) < self.batch_size:  # do not update until the memory holds at least one batch
            return
        # sample a random batch of transitions from the replay memory
        state_batch, action_batch, reward_batch, next_state_batch, done_batch = self.memory.sample(
            self.batch_size)
        # convert to tensors
        state_batch = torch.tensor(np.array(state_batch), device=self.device, dtype=torch.float)
        action_batch = torch.tensor(action_batch, device=self.device).unsqueeze(1)
        reward_batch = torch.tensor(reward_batch, device=self.device, dtype=torch.float)
        next_state_batch = torch.tensor(np.array(next_state_batch), device=self.device, dtype=torch.float)
        done_batch = torch.tensor(np.float32(done_batch), device=self.device)
        q_values = self.policy_net(state_batch).gather(dim=1, index=action_batch)  # Q(s_t, a) for the actions actually taken
        next_q_values = self.target_net(next_state_batch).max(1)[0].detach()  # max_a Q(s_{t+1}, a) from the target network
        # expected Q values; for terminal transitions done_batch = 1, so the target reduces to the reward
        expected_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
        loss = nn.MSELoss()(q_values, expected_q_values.unsqueeze(1))  # mean squared error loss
        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.policy_net.parameters():  # clip gradients to prevent them from exploding
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

    def save(self, path):
        torch.save(self.target_net.state_dict(), path + 'dqn_checkpoint.pth')

    def load(self, path):
        self.target_net.load_state_dict(torch.load(path + 'dqn_checkpoint.pth'))
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            param.data.copy_(target_param.data)
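
The commit defines the agent but not a training loop. Below is a minimal sketch of how the agent might be driven; the Config class, the gym environment, the preprocess helper, and the periodic target-network sync interval are all assumptions for illustration, not part of this commit.

import gym  # classic gym API assumed (env.step returns a 4-tuple)

class Config:
    ''' Hypothetical config; the field names mirror what DQN.__init__ reads from cfg. '''
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gamma = 0.99
    epsilon_start = 1.0
    epsilon_end = 0.01
    epsilon_decay = 500
    lr = 1e-4
    memory_capacity = 10000
    batch_size = 32

def preprocess(obs):
    # the CNN expects channel-first float input, so transpose HWC uint8 frames and rescale
    return np.transpose(obs, (2, 0, 1)).astype(np.float32) / 255.0

cfg = Config()
env = gym.make('PongNoFrameskip-v4')  # assumed environment; any image-observation env works
n_states = preprocess(env.reset()).shape  # channel-first shape, e.g. (3, 210, 160)
n_actions = env.action_space.n
agent = DQN(n_states, n_actions, cfg)

for episode in range(10):
    state = preprocess(env.reset())
    done = False
    while not done:
        action = agent.choose_action(state)
        next_obs, reward, done, _ = env.step(action)
        next_state = preprocess(next_obs)
        agent.memory.push(state, action, reward, next_state, done)
        agent.update()
        state = next_state
    # the original file only copies policy_net into target_net at construction time;
    # syncing periodically (interval assumed here) is standard DQN practice
    if episode % 4 == 0:
        agent.target_net.load_state_dict(agent.policy_net.state_dict())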