Commit

Use state_is_tuple=True
futurecrew committed Nov 15, 2016
1 parent 7045b41 commit 9b77f09
Showing 5 changed files with 31 additions and 26 deletions.
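
The commit migrates the A3C-LSTM code from TensorFlow's legacy concatenated LSTM state (a single (1, 2 * hidden_size) tensor holding c and h side by side) to state_is_tuple=True, where the cell state c and hidden state h travel as a separate pair. A minimal sketch of the difference, assuming a TensorFlow version of the same era where tf.nn.rnn_cell.BasicLSTMCell exists (cell names come from the diff; the rest is illustrative):

import tensorflow as tf

hidden_size = 256

# Old style: the caller manages one concatenated tensor of shape (1, 512).
old_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0,
                                        state_is_tuple=False)
old_state = tf.placeholder(tf.float32, (1, hidden_size * 2))

# New style: zero_state() yields an LSTMStateTuple(c, h), each (1, 256),
# so c and h are fed and fetched separately.
new_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0,
                                        state_is_tuple=True)
new_state = new_cell.zero_state(1, tf.float32)
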
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
/snapshot/
/output/
6 changes: 4 additions & 2 deletions arguments.py
@@ -13,7 +13,7 @@ def get_args():
parser = argparse.ArgumentParser()

parser.add_argument('rom', type=str, help='ALE rom file')
parser.add_argument('--asynchronousRL', type=bool, default=True, help='')
parser.add_argument('--asynchronousRL', action='store_true', help='whether to use asynchronous learning')
parser.add_argument('--asynchronousRL-type', type=str, default='A3C_LSTM', help='A3C_LSTM, A3C, 1Q')
parser.add_argument('--multi-thread-no', type=int, default=1, help='Number of multiple threads for Asynchronous RL')
parser.add_argument('--network-type', type=str, default='nips', help='network model nature or nips')
@@ -22,6 +22,9 @@ def get_args():
parser.add_argument('--replay-file', type=str, default=None, help='trained file to replay')
parser.add_argument('--device', type=str, default='', help='(gpu, cpu)')
parser.add_argument('--env', type=str, default='ale', help='environment(ale, vizdoom)')
parser.add_argument('--show-screen', action='store_true', help='whether to show display or not')
parser.set_defaults(asynchronousRL=True)
parser.set_defaults(show_screen=False)

args = parser.parse_args()

@@ -30,7 +33,6 @@ def get_args():
args.screen_height = 84 # input screen height
args.screen_history = 4 # input screen history
args.frame_repeat = 4 # how many frames to repeat in ale for one predicted action
args.show_screen = False # whether to show ale display
args.use_ale_frame_skip = False # whether to use ale frame_skip feature
args.discount_factor = 0.99 # RL discount factor
args.test_step = 125000 # test for this number of steps
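
A side effect worth noting in this hunk: argparse's type=bool converts any non-empty string to True (bool('False') is True in Python), so the old --asynchronousRL flag could never actually be disabled from the command line. The action='store_true' plus set_defaults pattern adopted here is the idiomatic replacement; a minimal sketch:

import argparse

parser = argparse.ArgumentParser()
# Presence of the flag sets it True; set_defaults supplies the default value.
parser.add_argument('--show-screen', action='store_true',
                    help='whether to show display or not')
parser.set_defaults(show_screen=False)

args = parser.parse_args([])                 # args.show_screen == False
args = parser.parse_args(['--show-screen'])  # args.show_screen == True
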
17 changes: 6 additions & 11 deletions deep_rl_train.py
@@ -66,7 +66,7 @@ def __init__(self, args, play_file=None, thread_no=0, global_list=None):

def initialize_post(self):
""" initialization that should be run on __init__() or after deserialization """
if (self.args.show_screen and self.thread_no == 0) or self.play_file is not None:
if self.args.show_screen and self.thread_no == 0:
display_screen = True
else:
display_screen = False
@@ -279,11 +279,7 @@ def do_actions(self, action_index, mode):
game_over = True
break
new_state = np.maximum(prev_state, self.current_state)

try:
resized = cv2.resize(new_state, (self.args.screen_height, self.args.screen_width))
except:
pass
resized = cv2.resize(new_state, (self.args.screen_height, self.args.screen_width))
return reward, resized, terminal, game_over

def generate_replay_memory(self, count):
@@ -467,8 +463,9 @@ def train_async_a3c(self, replay_memory_no=None):
learning_rate = self._anneal_learning_rate(max_global_step_no, global_step_no)

if self.args.asynchronousRL_type == 'A3C_LSTM':
self.model_runner.set_lstm_state(lstm_state_value)
self.model_runner.train(prestates, v_pres, actions, rewards, terminals, v_post, learning_rate)
self.model_runner.train(prestates, v_pres, actions, rewards, terminals, v_post, learning_rate, lstm_state_value)
else:
self.model_runner.train(prestates, v_pres, actions, rewards, terminals, v_post, learning_rate)

self.train_step += 1

@@ -577,8 +574,6 @@ def get_env(args, initialize, show_screen):
save_file = args.retrain_file

if args.asynchronousRL == True:
global global_step_no

threadList = []
playerList = []

@@ -628,7 +623,7 @@ def get_env(args, initialize, show_screen):

# copy global variables to local variables
for i in range(args.multi_thread_no):
playerList[i].model_runner.copy_from_global_to_local()
playerList[i].model_runner.copy_from_global_to_local()
else:
for i in range(args.multi_thread_no):
print 'creating a thread[%s]' % i
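
The training-loop change above folds the previous set_lstm_state() call into train() itself, so the LSTM state captured when the rollout began is handed over together with the batch. A hedged sketch of the intended call pattern (method names come from the diff; where lstm_state_value is captured is an assumption):

# Sketch: snapshot the (c, h) state before the rollout so train() can
# re-unroll the same sequence from the same starting state.
lstm_state_value = self.model_runner.get_lstm_state()
# ... rollout fills prestates, v_pres, actions, rewards, terminals ...
if self.args.asynchronousRL_type == 'A3C_LSTM':
    self.model_runner.train(prestates, v_pres, actions, rewards,
                            terminals, v_post, learning_rate,
                            lstm_state_value)
else:
    self.model_runner.train(prestates, v_pres, actions, rewards,
                            terminals, v_post, learning_rate)
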
26 changes: 16 additions & 10 deletions model_runner_tf_a3c_lstm.py
@@ -20,7 +20,7 @@ def init_models(self):
self.y_class = self.model.y_class
self.v = self.model.v

self.lstm_state = self.model.lstm_state
self.lstm_init_state = self.model.lstm_init_state
self.lstm_next_state = self.model.lstm_next_state
self.sequence_length = self.model.sequence_length
self.lstm_hidden_size = 256
@@ -78,7 +78,8 @@ def get_loss(self):
def predict_action_state(self, state):
y_class, v, self.lstm_state_value = self.sess.run([self.y_class, self.v, self.lstm_next_state], feed_dict={
self.x_in: state,
self.lstm_state: self.lstm_state_value,
self.lstm_init_state.c : self.lstm_state_value[0],
self.lstm_init_state.h : self.lstm_state_value[1],
self.sequence_length : [1],
})
return y_class[0], v[0]
@@ -88,30 +89,34 @@ def predict_state(self, state):
# because this function does not process to a new screen frame
v, _ = self.sess.run([self.v, self.lstm_next_state], feed_dict={
self.x_in: state,
self.lstm_state: self.lstm_state_value,
self.lstm_init_state.c : self.lstm_state_value[0],
self.lstm_init_state.h : self.lstm_state_value[1],
self.sequence_length : [1],
})
return v[0]

def predict(self, state):
y_class, self.lstm_state_value = self.sess.run([self.y_class, self.lstm_next_state], feed_dict={
self.x_in: state,
self.lstm_state: self.lstm_state_value,
self.lstm_init_state.c : self.lstm_state_value[0],
self.lstm_init_state.h : self.lstm_state_value[1],
self.sequence_length : [1],
})
return y_class[0]

def reset_lstm_state(self):
self.lstm_state_value = np.zeros((1, self.lstm_hidden_size * 2))
self.lstm_state_value = self.sess.run(self.lstm_init_state)

def get_lstm_state(self):
return self.lstm_state_value.copy()
return self.lstm_state_value

def set_lstm_state(self, lstm_state_value):
self.lstm_state_value = lstm_state_value
def set_lstm_state(self, lstm_state_value_c, lstm_state_value_h):
self.lstm_state_value[0] = lstm_state_value_c
self.lstm_state_value[1] = lstm_state_value_h

def train(self, prestates, v_pres, actions, rewards, terminals, v_post, learning_rate):
def train(self, prestates, v_pres, actions, rewards, terminals, v_post, learning_rate, lstm_state_value):
data_len = len(actions)
self.lstm_state_value = lstm_state_value

action_mat = np.zeros((data_len, self.max_action_no), dtype=np.uint8)
v_in = np.zeros(data_len)
@@ -136,7 +141,8 @@ def train(self, prestates, v_pres, actions, rewards, terminals, v_post, learning
self.a_in: action_mat,
self.v_in: v_in,
self.td_in: td_in,
self.lstm_state : self.lstm_state_value,
self.lstm_init_state.c : self.lstm_state_value[0],
self.lstm_init_state.h : self.lstm_state_value[1],
self.sequence_length : [data_len],
})

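
With the tuple state, sess.run(lstm_init_state) returns a pair of NumPy arrays rather than one (1, 512) array, and each half is fed back through its own placeholder. A condensed sketch of the feed pattern this file now uses throughout (attribute names are from the diff):

# lstm_state_value holds (c, h): two NumPy arrays of shape (1, 256) each,
# initially obtained via sess.run(self.lstm_init_state).
feed_dict = {
    self.x_in: state,
    self.lstm_init_state.c: self.lstm_state_value[0],  # cell state c
    self.lstm_init_state.h: self.lstm_state_value[1],  # hidden state h
    self.sequence_length: [1],
}
y_class, v, self.lstm_state_value = self.sess.run(
    [self.y_class, self.v, self.lstm_next_state], feed_dict=feed_dict)
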
6 changes: 3 additions & 3 deletions network_model.py
@@ -314,12 +314,12 @@ def build_network_nips(self, name, trainable, num_actions):

with tf.variable_scope('LSTM'):
hidden_size = 256
self.lstm_state = tf.placeholder(tf.float32, (1, hidden_size * 2))
self.sequence_length = tf.placeholder(tf.int32)
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0)
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0, state_is_tuple=True)
self.lstm_init_state = lstm_cell.zero_state(1, tf.float32)
h_fc1_reshape = tf.reshape(h_fc1, [-1, 1, 256])
print h_fc1_reshape
outputs, self.lstm_next_state = tf.nn.dynamic_rnn(cell, h_fc1_reshape, initial_state=self.lstm_state, sequence_length=self.sequence_length, time_major=True)
outputs, self.lstm_next_state = tf.nn.dynamic_rnn(lstm_cell, h_fc1_reshape, initial_state=self.lstm_init_state, sequence_length=self.sequence_length, time_major=True)
print('outputs : %s' % outputs) # (5, 1, 256)
outputs = tf.squeeze(outputs, [1]) # (5, 256)
print('outputs : %s' % outputs)
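
Put together, the new graph construction is self-contained enough to sketch in isolation: with time_major=True, the reshaped features of shape (T, 1, 256) are unrolled for T steps over a batch of one, and lstm_next_state is an LSTMStateTuple matching lstm_init_state (shapes here are assumptions based on the prints above):

import tensorflow as tf

hidden_size = 256
h_fc1 = tf.placeholder(tf.float32, (None, hidden_size))  # T frame features
sequence_length = tf.placeholder(tf.int32)

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=1.0,
                                         state_is_tuple=True)
lstm_init_state = lstm_cell.zero_state(1, tf.float32)     # LSTMStateTuple(c, h)
inputs = tf.reshape(h_fc1, [-1, 1, hidden_size])          # (T, 1, 256), time major
outputs, lstm_next_state = tf.nn.dynamic_rnn(
    lstm_cell, inputs, initial_state=lstm_init_state,
    sequence_length=sequence_length, time_major=True)
outputs = tf.squeeze(outputs, [1])                        # (T, 256)
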
