Two RL examples and examples refactoring #211

Merged
merged 4 commits on Jul 16, 2018
Changes from 1 commit
[WIP] Actor-critic example
vfdev-5 committed Jul 4, 2018
commit 82ebef84deae123f0ece6264718b7455ae5d9cd2
153 changes: 153 additions & 0 deletions examples/actor_critic.py
@@ -0,0 +1,153 @@
import argparse
from itertools import count
from collections import namedtuple

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

try:
    import gym
except ImportError:
    raise RuntimeError("Please install OpenAI Gym: pip install gym")


from ignite.engine import Engine, Events


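# SavedAction pairs the log-probability of a sampled action with the critic's
# value estimate for the state in which it was taken.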
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])


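# Actor-critic network for CartPole: a shared hidden layer feeds an action
# (policy) head producing a distribution over the 2 actions and a value
# (critic) head estimating the state value. Episode buffers for saved actions
# and rewards live on the module itself.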
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.action_head = nn.Linear(128, 2)
        self.value_head = nn.Linear(128, 1)

        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.action_head(x)
        state_values = self.value_head(x)
        return F.softmax(action_scores, dim=-1), state_values


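# Sample an action from the categorical distribution given by the policy head
# and record its log-probability and the value estimate for the later update.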
def select_action(model, observation):
    observation = torch.from_numpy(observation).float()
    probs, observation_value = model(observation)
    m = Categorical(probs)
    action = m.sample()
    model.saved_actions.append(SavedAction(m.log_prob(action), observation_value))
    return action.item()


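# End-of-episode update: compute discounted returns, normalize them, and combine
# the policy loss (negative log-probability weighted by the advantage) with a
# smooth L1 value loss, then take one optimizer step and clear the episode buffers.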
def finish_episode(model, optimizer, gamma, eps):
    R = 0
    saved_actions = model.saved_actions
    policy_losses = []
    value_losses = []
    rewards = []
    for r in model.rewards[::-1]:
        R = r + gamma * R
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + eps)
    for (log_prob, value), r in zip(saved_actions, rewards):
        reward = r - value.item()
        policy_losses.append(-log_prob * reward)
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([r])))
    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()
    optimizer.step()
    del model.rewards[:]
    del model.saved_actions[:]


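# Training loop built on an Ignite Engine: each epoch is one episode and each
# iteration is one environment timestep.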
def main(env, args):

    model = Policy()
    optimizer = optim.Adam(model.parameters(), lr=3e-2)
    eps = np.finfo(np.float32).eps.item()

    def run_single_timestep(engine, _):
        # Workaround: skip the timestep once the episode is done, since Ignite
        # currently cannot terminate an epoch early without stopping the whole run.
        if engine.state.done:
            return

        observation = engine.state.observation
        action = select_action(model, observation)
        engine.state.observation, reward, engine.state.done, _ = env.step(action)
        if args.render:
            env.render()
        model.rewards.append(reward)

        if engine.state.done:
            # We only need to stop `_run_once_on_dataset` here, not `run` itself, so
            # instead of setting `engine.should_terminate` we record the episode length.
            engine.state.timestep = engine.state.iteration % len(timesteps)

    trainer = Engine(run_single_timestep)

    timesteps = list(range(10000))

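    # Handlers: initialize the running reward, reset the environment at the start
    # of each episode, and update / log / check the solved condition when it ends.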
    @trainer.on(Events.STARTED)
    def initialize(engine):
        engine.state.running_reward = 10

    @trainer.on(Events.EPOCH_STARTED)
    def reset_environment_state(engine):
        engine.state.observation = env.reset()
        engine.state.done = False

    @trainer.on(Events.EPOCH_COMPLETED)
    def update_model(engine):
        t = engine.state.timestep
        engine.state.running_reward = engine.state.running_reward * 0.99 + t * 0.01
        finish_episode(model, optimizer, args.gamma, eps)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_episode(engine):
        i_episode = engine.state.epoch
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, engine.state.timestep, engine.state.running_reward))

    @trainer.on(Events.EPOCH_COMPLETED)
    def finish_training(engine):
        running_reward = engine.state.running_reward
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, engine.state.timestep))
            engine.should_terminate = True

    trainer.run(timesteps, max_epochs=1000000)


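# Example: python examples/actor_critic.py --render --log-interval 10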
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Ignite actor-critic example')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor (default: 0.99)')
    parser.add_argument('--seed', type=int, default=543, metavar='N',
                        help='random seed (default: 543)')
    parser.add_argument('--render', action='store_true',
                        help='render the environment')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='interval between training status logs (default: 10)')
    args = parser.parse_args()

    env = gym.make('CartPole-v0')
    env.seed(args.seed)
    torch.manual_seed(args.seed)

    main(env, args)