Improve example filtering
guillaumekln committed Aug 17, 2017
1 parent 498b711 commit 92e054b
Showing 10 changed files with 70 additions and 141 deletions.
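In summary: the per-model filtering hooks (`set_filters`, `_filter_example`, `_get_size`, `_get_maximum_size`) are replaced by a single length-based filter in the base `Model` class. Models now only expose `features_length` and `labels_length`, and the length limits come from two new data configuration keys, `maximum_features_length` and `maximum_labels_length`.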
8 changes: 0 additions & 8 deletions config/models/nmt_medium.py
@@ -23,11 +23,3 @@ def model():
           cell_class=tf.contrib.rnn.LSTMCell,
           dropout=0.3,
           residual_connections=False))
-
-def train(model):
-  model.set_filters(
-      maximum_source_length=70,
-      maximum_target_length=70)
-
-def infer(model):
-  pass
6 changes: 0 additions & 6 deletions config/models/pos_tagger.py
@@ -29,9 +29,3 @@ def model():
           residual_connections=False),
       labels_vocabulary_file="data/wsj/tags.txt",
       crf_decoding=True)
-
-def train(model):
-  model.set_filters(maximum_length=70)
-
-def infer(model):
-  pass
21 changes: 0 additions & 21 deletions config/models/template.py
@@ -11,24 +11,3 @@ def model():
     A `opennmt.models.Model`.
   """
   pass
-
-def train(model):
-  """Run training specific code.
-
-  You usually call methods on `model` to set training specific
-  attributes, e.g. the longest sequence lengths accepted.
-
-  Args:
-    model: The model previously built.
-  """
-  pass
-
-def infer(model):
-  """Run inference specific code.
-
-  Similar to `train` but for inference.
-
-  Args:
-    model: The model previously built.
-  """
-  pass
8 changes: 0 additions & 8 deletions config/models/transformer.py
@@ -15,11 +15,3 @@ def model():
       num_heads=8,
       ffn_inner_dim=2048,
       dropout=0.1)
-
-def train(model):
-  model.set_filters(
-      maximum_source_length=70,
-      maximum_target_length=70)
-
-def infer(model):
-  pass
6 changes: 6 additions & 0 deletions config/train.sample.yml
@@ -45,6 +45,12 @@ data:
   eval_features_file: "data/en-test.txt"
   eval_labels_file: "data/fr-test.txt"
 
+  # (optional) The maximum length of feature sequences during training (if it applies).
+  maximum_features_length: 70
+
+  # (optional) The maximum length of label sequences during training (if it applies).
+  maximum_labels_length: 70
+
   # (optional) The pre-fetch buffer size (e.g. for shuffling examples).
   buffer_size: 10000
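Both keys are optional. When they are absent, `config["data"].get(...)` in onmt.py (below) returns None, `_filter_example` receives no maximum, and only the non-empty length checks apply.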
8 changes: 3 additions & 5 deletions onmt.py
@@ -85,15 +85,15 @@ def main():
       params=params)
 
   if config["run"]["type"] == "train":
-    model_config.train(model)
-
     train_input_fn = model.input_fn(
         tf.estimator.ModeKeys.TRAIN,
         config["params"]["batch_size"],
         buffer_size,
         num_buckets,
         config["data"]["train_features_file"],
-        labels_file=config["data"]["train_labels_file"])
+        labels_file=config["data"]["train_labels_file"],
+        maximum_features_length=config["data"].get("maximum_features_length"),
+        maximum_labels_length=config["data"].get("maximum_labels_length"))
     eval_input_fn = model.input_fn(
         tf.estimator.ModeKeys.EVAL,
         config["params"]["batch_size"],
@@ -116,8 +116,6 @@ def main():
     else:
       experiment.train()
   else:
-    model_config.infer(model)
-
    test_input_fn = model.input_fn(
        tf.estimator.ModeKeys.PREDICT,
        config["params"]["batch_size"],
77 changes: 56 additions & 21 deletions opennmt/models/model.py
@@ -1,9 +1,10 @@
 """Base class for models."""
 
-import tensorflow as tf
-
 import abc
 import six
+import time
+
+import tensorflow as tf
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -37,21 +38,43 @@ def _build_train_op(self, loss, params):
 
     return train_op
 
-  def _filter_example(self, features, labels):
+  def _filter_example(self,
+                      features,
+                      labels,
+                      maximum_features_length=None,
+                      maximum_labels_length=None):
     """Defines an example filtering condition."""
-    return True
+    features_length = self.features_length(features)
+    labels_length = self.labels_length(labels)
 
-  @abc.abstractmethod
-  def _get_size(self, features, labels):
-    """Defines a size to an example for data bucketing."""
-    raise NotImplementedError()
+    cond = []
+
+    if features_length is not None:
+      cond.append(tf.greater(features_length, 0))
+      if maximum_features_length is not None:
+        cond.append(tf.less_equal(features_length, maximum_features_length))
+
+    if labels_length is not None:
+      cond.append(tf.greater(labels_length, 0))
+      if maximum_labels_length is not None:
+        cond.append(tf.less_equal(labels_length, maximum_labels_length))
+
+    return tf.reduce_all(cond)
 
-  def _get_maximum_size(self):
-    """Defines the maximum size of an example for data bucketing."""
+  def features_length(self, features):
+    """Attributes a length to a feature (if defined)."""
     return None
 
+  def labels_length(self, labels):
+    """Attributes a length to a label (if defined)."""
+    return None
+
   @abc.abstractmethod
-  def _build_dataset(self, mode, batch_size, features_file, labels_file=None):
+  def _build_dataset(self,
+                     mode,
+                     batch_size,
+                     features_file,
+                     labels_file=None):
     """Builds a dataset from features and labels files.
 
     Args:
@@ -71,7 +94,9 @@ def _input_fn_impl(self,
                      buffer_size,
                      num_buckets,
                      features_file,
-                     labels_file=None):
+                     labels_file=None,
+                     maximum_features_length=None,
+                     maximum_labels_length=None):
     """See `input_fn`."""
     dataset, padded_shapes = self._build_dataset(
         mode,
@@ -80,8 +105,12 @@ def _input_fn_impl(self,
         labels_file=labels_file)
 
     if mode == tf.estimator.ModeKeys.TRAIN:
-      dataset = dataset.filter(self._filter_example)
-      dataset = dataset.shuffle(buffer_size=buffer_size)
+      dataset = dataset.filter(lambda features, labels: self._filter_example(
+          features,
+          labels,
+          maximum_features_length=maximum_features_length,
+          maximum_labels_length=maximum_labels_length))
+      dataset = dataset.shuffle(buffer_size, seed=int(time.time()))
       dataset = dataset.repeat()
     elif mode == tf.estimator.ModeKeys.EVAL:
       dataset = dataset.repeat()
@@ -94,14 +123,12 @@ def _input_fn_impl(self,
     # For training and evaluation, use bucketing.
 
     def key_func(features, labels):
-      maximum_size = self._get_maximum_size()
-
-      if maximum_size:
-        bucket_width = (maximum_size + num_buckets - 1) // num_buckets
+      if maximum_features_length:
+        bucket_width = (maximum_features_length + num_buckets - 1) // num_buckets
       else:
         bucket_width = 10
 
-      bucket_id = self._get_size(features, labels) // bucket_width
+      bucket_id = self.features_length(features) // bucket_width
       bucket_id = tf.minimum(bucket_id, num_buckets)
       return tf.to_int64(bucket_id)
 
@@ -128,7 +155,9 @@ def input_fn(self,
               buffer_size,
               num_buckets,
              features_file,
-              labels_file=None):
+              labels_file=None,
+              maximum_features_length=None,
+              maximum_labels_length=None):
     """Returns an input function.
 
     See also `tf.estimator.Estimator`.
@@ -140,6 +169,10 @@ def input_fn(self,
       num_buckets: The number of buckets to store examples of similar sizes.
       features_file: The file containing input features.
       labels_file: The file containing output labels.
+      maximum_features_length: The maximum length of feature sequences
+        during training (if it applies).
+      maximum_labels_length: The maximum length of label sequences
+        during training (if it applies).
 
     Returns:
       A callable that returns the next element.
@@ -153,7 +186,9 @@ def input_fn(self,
         buffer_size,
         num_buckets,
         features_file,
-        labels_file=labels_file)
+        labels_file=labels_file,
+        maximum_features_length=maximum_features_length,
+        maximum_labels_length=maximum_labels_length)
 
   def format_prediction(self, prediction, params=None):
     """Formats the model prediction.
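The new `_filter_example` builds one boolean tensor per constraint and conjoins them with `tf.reduce_all`. Below is a minimal standalone sketch of that pattern, written against the current `tf.data` API (at the time of this commit the Dataset class lived in `tf.contrib.data`); the toy ids, lengths, and the fixed maximum are invented for illustration:

import tensorflow as tf

maximum_features_length = 3

# Toy (ids, length) pairs standing in for real feature dictionaries.
ids = tf.data.Dataset.from_tensor_slices([[1, 2, 0, 0], [1, 2, 3, 4]])
lengths = tf.data.Dataset.from_tensor_slices([2, 4])
dataset = tf.data.Dataset.zip((ids, lengths))

# Keep examples whose length is positive and within the configured maximum,
# mirroring the conjunction _filter_example builds from its condition list.
dataset = dataset.filter(lambda ids, length: tf.reduce_all(
    [tf.greater(length, 0),
     tf.less_equal(length, maximum_features_length)]))
# Only the first example (length 2) survives the filter.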
21 changes: 1 addition & 20 deletions opennmt/models/sequence_classifier.py
@@ -19,29 +19,10 @@ def __init__(self,
     self.encoder = encoder
     self.labels_vocabulary_file = labels_vocabulary_file
     self.num_labels = count_lines(labels_vocabulary_file)
-    self.maximum_length = 0
 
-  def set_filters(self, maximum_length):
-    self.maximum_length = maximum_length
-
-  def _get_size(self, features, labels):
+  def features_length(self, features):
     return self.embedder.get_data_field(features, "length")
 
-  def _get_maximum_size(self):
-    return getattr(self, "maximum_length", None)
-
-  def _filter_example(self, features, labels):
-    """Filters examples with invalid length."""
-    cond = tf.greater(self.embedder.get_data_field(features, "length"), 0)
-
-    if self.maximum_length > 0:
-      cond = tf.logical_and(
-          cond,
-          tf.less_equal(self.embedder.get_data_field(features, "length"),
-                        self.maximum_length))
-
-    return cond
-
   def _build_dataset(self, mode, batch_size, features_file, labels_file=None):
     features_dataset = self.embedder.make_dataset(features_file)
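With filtering centralized, `SequenceClassifier` (and `SequenceTagger` below) keeps only a `features_length` override; `labels_length` falls back to the base-class default of None, so label lengths are neither checked nor bounded for these models.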
21 changes: 1 addition & 20 deletions opennmt/models/sequence_tagger.py
@@ -22,34 +22,15 @@ def __init__(self,
     self.labels_vocabulary_file = labels_vocabulary_file
     self.num_labels = count_lines(labels_vocabulary_file)
     self.crf_decoding = crf_decoding
-    self.maximum_length = 0
 
     self.id_to_label = []
     with open(labels_vocabulary_file) as labels_vocabulary:
       for label in labels_vocabulary:
         self.id_to_label.append(label.strip())
 
-  def set_filters(self, maximum_length):
-    self.maximum_length = maximum_length
-
-  def _get_size(self, features, labels):
+  def features_length(self, features):
     return self.embedder.get_data_field(features, "length")
 
-  def _get_maximum_size(self):
-    return getattr(self, "maximum_length", None)
-
-  def _filter_example(self, features, labels):
-    """Filters examples with invalid length."""
-    cond = tf.greater(self.embedder.get_data_field(features, "length"), 0)
-
-    if self.maximum_length > 0:
-      cond = tf.logical_and(
-          cond,
-          tf.less_equal(self.embedder.get_data_field(features, "length"),
-                        self.maximum_length))
-
-    return cond
-
   def _build_dataset(self, mode, batch_size, features_file, labels_file=None):
     features_dataset = self.embedder.make_dataset(features_file)
35 changes: 3 additions & 32 deletions opennmt/models/sequence_to_sequence.py
@@ -27,40 +27,11 @@ def __init__(self,
     self.source_embedder.set_name("source")
     self.target_embedder.set_name("target")
 
-    self.maximum_source_length = 0
-    self.maximum_target_length = 0
-
-  def set_filters(self,
-                  maximum_source_length,
-                  maximum_target_length):
-    self.maximum_source_length = maximum_source_length
-    self.maximum_target_length = maximum_target_length
-
-  def _get_size(self, features, labels):
+  def features_length(self, features):
     return self.source_embedder.get_data_field(features, "length")
 
-  def _get_maximum_size(self):
-    return getattr(self, "maximum_source_length", None)
-
-  def _filter_example(self, features, labels):
-    """Filters examples with invalid length."""
-    cond = tf.logical_and(
-        tf.greater(self.source_embedder.get_data_field(features, "length"), 0),
-        tf.greater(self.target_embedder.get_data_field(labels, "length"), 0))
-
-    if self.maximum_source_length > 0:
-      cond = tf.logical_and(
-          cond,
-          tf.less_equal(self.source_embedder.get_data_field(features, "length"),
-                        self.maximum_source_length))
-
-    if self.maximum_target_length > 0:
-      cond = tf.logical_and(
-          cond,
-          tf.less_equal(self.target_embedder.get_data_field(labels, "length"),
-                        self.maximum_target_length + 1))  # "+ 1" because <s> was already added.
-
-    return cond
+  def labels_length(self, labels):
+    return self.target_embedder.get_data_field(labels, "length")
 
   def _shift_target(self, labels):
     """Generate shifted target sequences with <s> and </s>."""
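`key_func` in model.py now derives its bucket width from `maximum_features_length` directly instead of the removed `_get_maximum_size` hook. A pure-Python sketch of the arithmetic (the concrete numbers are illustrative, not from the commit):

def bucket_id(features_length, num_buckets, maximum_features_length=None):
    # Ceil-divide the maximum length into num_buckets buckets, as key_func
    # does with TensorFlow ops in opennmt/models/model.py.
    if maximum_features_length:
        bucket_width = (maximum_features_length + num_buckets - 1) // num_buckets
    else:
        bucket_width = 10  # fallback width when no maximum is configured
    return min(features_length // bucket_width, num_buckets)

# With maximum_features_length=70 and num_buckets=5, bucket_width is 14,
# so a 30-token example lands in bucket 30 // 14 = 2.
print(bucket_id(30, 5, maximum_features_length=70))  # 2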
