forked from OpenNMT/OpenNMT-tf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create a model catalog in the library (OpenNMT#102)
- Loading branch information
1 parent
791da62
commit 2158e29
Showing
21 changed files
with
413 additions
and
317 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,3 @@ | ||
"""Defines a character-based sequence-to-sequence model. | ||
from opennmt.models.catalog import CharacterSeq2Seq | ||
|
||
Character vocabularies can be built with: | ||
python -m bin.build_vocab --tokenizer CharacterTokenizer ... | ||
""" | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_chars_vocabulary", | ||
embedding_size=30, | ||
tokenizer=onmt.tokenizers.CharacterTokenizer()), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_chars_vocabulary", | ||
embedding_size=30, | ||
tokenizer=onmt.tokenizers.CharacterTokenizer()), | ||
encoder=onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=4, | ||
num_units=512, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False), | ||
decoder=onmt.decoders.AttentionalRNNDecoder( | ||
num_layers=4, | ||
num_units=512, | ||
bridge=onmt.layers.CopyBridge(), | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = CharacterSeq2Seq |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,3 @@ | ||
"""Defines a model similar to the "Listen, Attend and Spell" model described | ||
in https://arxiv.org/abs/1508.01211. | ||
""" | ||
from opennmt.models.catalog import ListenAttendSpell | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.SequenceRecordInputter(), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_vocabulary", | ||
embedding_size=50), | ||
encoder=onmt.encoders.PyramidalRNNEncoder( | ||
num_layers=3, | ||
num_units=512, | ||
reduction_factor=2, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3), | ||
decoder=onmt.decoders.MultiAttentionalRNNDecoder( | ||
num_layers=3, | ||
num_units=512, | ||
attention_layers=[0], | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongMonotonicAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = ListenAttendSpell |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,3 @@ | ||
"""Defines a sequence to sequence model with multiple input features. For | ||
example, this could be words, parts of speech, and lemmas that are embedded in | ||
parallel and concatenated into a single input embedding. The features are | ||
separate data files with separate vocabularies. | ||
""" | ||
from opennmt.models.catalog import MultiFeaturesNMT | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.ParallelInputter([ | ||
onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_words_vocabulary", | ||
embedding_size=512), | ||
onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="feature_1_vocabulary", | ||
embedding_size=16), | ||
onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="feature_2_vocabulary", | ||
embedding_size=64)], | ||
reducer=onmt.layers.ConcatReducer()), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_words_vocabulary", | ||
embedding_size=512), | ||
encoder=onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=4, | ||
num_units=512, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False), | ||
decoder=onmt.decoders.AttentionalRNNDecoder( | ||
num_layers=4, | ||
num_units=512, | ||
bridge=onmt.layers.CopyBridge(), | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = MultiFeaturesNMT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,3 @@ | ||
"""Defines a multi source sequence to sequence model. Source sequences are read | ||
from 2 files, encoded separately, and the encoder outputs are concatenated in | ||
time. | ||
""" | ||
from opennmt.models.catalog import MultiSourceNMT | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.ParallelInputter([ | ||
onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_vocabulary_1", | ||
embedding_size=512), | ||
onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_vocabulary_2", | ||
embedding_size=512)]), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_vocabulary", | ||
embedding_size=512), | ||
encoder=onmt.encoders.ParallelEncoder([ | ||
onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=2, | ||
num_units=512, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False), | ||
onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=2, | ||
num_units=512, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)], | ||
outputs_reducer=onmt.layers.ConcatReducer(axis=1)), | ||
decoder=onmt.decoders.AttentionalRNNDecoder( | ||
num_layers=4, | ||
num_units=512, | ||
bridge=onmt.layers.DenseBridge(), | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = MultiSourceNMT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,3 @@ | ||
"""Defines a medium-sized bidirectional LSTM encoder-decoder model.""" | ||
from opennmt.models.catalog import NMTMedium | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_words_vocabulary", | ||
embedding_size=512), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_words_vocabulary", | ||
embedding_size=512), | ||
encoder=onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=4, | ||
num_units=512, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False), | ||
decoder=onmt.decoders.AttentionalRNNDecoder( | ||
num_layers=4, | ||
num_units=512, | ||
bridge=onmt.layers.CopyBridge(), | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = NMTMedium |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,3 @@ | ||
"""Defines a medium-sized bidirectional LSTM encoder-decoder model with | ||
experimental FP16 data type. | ||
""" | ||
from opennmt.models.catalog import NMTMediumFP16 | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_words_vocabulary", | ||
embedding_size=512, | ||
dtype=tf.float16), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_words_vocabulary", | ||
embedding_size=512, | ||
dtype=tf.float16), | ||
encoder=onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=4, | ||
num_units=512, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False), | ||
decoder=onmt.decoders.AttentionalRNNDecoder( | ||
num_layers=4, | ||
num_units=512, | ||
bridge=onmt.layers.CopyBridge(), | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = NMTMediumFP16 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,3 @@ | ||
"""Defines a small unidirectional LSTM encoder-decoder model.""" | ||
from opennmt.models.catalog import NMTSmall | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceToSequence( | ||
source_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_words_vocabulary", | ||
embedding_size=512), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_words_vocabulary", | ||
embedding_size=512), | ||
encoder=onmt.encoders.UnidirectionalRNNEncoder( | ||
num_layers=2, | ||
num_units=512, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False), | ||
decoder=onmt.decoders.AttentionalRNNDecoder( | ||
num_layers=2, | ||
num_units=512, | ||
bridge=onmt.layers.CopyBridge(), | ||
attention_mechanism_class=tf.contrib.seq2seq.LuongAttention, | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.3, | ||
residual_connections=False)) | ||
model = NMTSmall |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,3 @@ | ||
"""Defines a bidirectional LSTM-CNNs-CRF as described in https://arxiv.org/abs/1603.01354.""" | ||
from opennmt.models.catalog import SeqTagger | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.SequenceTagger( | ||
inputter=onmt.inputters.MixedInputter([ | ||
onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="words_vocabulary", | ||
embedding_size=None, | ||
embedding_file_key="words_embedding", | ||
trainable=True), | ||
onmt.inputters.CharConvEmbedder( | ||
vocabulary_file_key="chars_vocabulary", | ||
embedding_size=30, | ||
num_outputs=30, | ||
kernel_size=3, | ||
stride=1, | ||
dropout=0.5)], | ||
dropout=0.5), | ||
encoder=onmt.encoders.BidirectionalRNNEncoder( | ||
num_layers=1, | ||
num_units=400, | ||
reducer=onmt.layers.ConcatReducer(), | ||
cell_class=tf.contrib.rnn.LSTMCell, | ||
dropout=0.5, | ||
residual_connections=False), | ||
labels_vocabulary_file_key="tags_vocabulary", | ||
crf_decoding=True) | ||
model = SeqTagger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,3 @@ | ||
"""Defines a Transformer model as described in https://arxiv.org/abs/1706.03762.""" | ||
from opennmt.models.catalog import Transformer | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.Transformer( | ||
source_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_words_vocabulary", | ||
embedding_size=512), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_words_vocabulary", | ||
embedding_size=512), | ||
num_layers=6, | ||
num_units=512, | ||
num_heads=8, | ||
ffn_inner_dim=2048, | ||
dropout=0.1, | ||
attention_dropout=0.1, | ||
relu_dropout=0.1) | ||
model = Transformer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,3 @@ | ||
"""Defines a Transformer model with experimental FP16 data type.""" | ||
from opennmt.models.catalog import TransformerFP16 | ||
|
||
import tensorflow as tf | ||
import opennmt as onmt | ||
|
||
def model(): | ||
return onmt.models.Transformer( | ||
source_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="source_words_vocabulary", | ||
embedding_size=512, | ||
dtype=tf.float16), | ||
target_inputter=onmt.inputters.WordEmbedder( | ||
vocabulary_file_key="target_words_vocabulary", | ||
embedding_size=512, | ||
dtype=tf.float16), | ||
num_layers=6, | ||
num_units=512, | ||
num_heads=8, | ||
ffn_inner_dim=2048, | ||
dropout=0.1, | ||
attention_dropout=0.1, | ||
relu_dropout=0.1) | ||
model = TransformerFP16 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.