Skip to content

Commit

Permalink
issue #17: Hyperparameter search LSTM with peep_holes added
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermevarela committed Sep 28, 2018
1 parent c6df2e6 commit 57b2b40
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions models/propagators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@


def get_unit(sz, ru='BasicLSTM'):
if ru not in ('BasicLSTM', 'GRU'):
raise ValueError('recurrent_unit {:} must be in {:}'.format (ru, ('BasicLSTM', 'GRU')))
ru_types = ('BasicLSTM', 'GRU', 'LSTM')
if ru not in ru_types:
raise ValueError('recurrent_unit {:} must be in {:}'.format (ru, ru_types))

if ru == 'BasicLSTM':
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(sz,
Expand All @@ -19,6 +20,11 @@ def get_unit(sz, ru='BasicLSTM'):
if ru == 'GRU':
rnn_cell = tf.nn.rnn_cell.GRUCell(sz)

if ru == 'LSTM':
rnn_cell = tf.nn.rnn_cell.LSTMCell(sz,
use_peepholes=False,
forget_bias=1.0,
state_is_tuple=True)
return rnn_cell


Expand Down
2 changes: 1 addition & 1 deletion srl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
Default: 0.005\n''')

parser.add_argument('--ru', dest='ru', type=str,
default='BasicLSTM', choices=('BasicLSTM', 'GRU'),
default='BasicLSTM', choices=('BasicLSTM', 'GRU', 'LSTM'),
help='''Recurrent unit -- according to tensorflow.
Default: `BasicLSTM`\n''')

Expand Down

0 comments on commit 57b2b40

Please sign in to comment.