Bugfix for polynomial learning rate schedule (LLNL#1984)
* Make polynomial learning rate schedule match Keras

* Add unit test for polynomial learning rate callback

* Use synthetic data reader in test for polynomial decay learning rate schedule

Tim Moon committed Nov 2, 2021
1 parent fda4fd0 commit c88901e
Showing 3 changed files with 178 additions and 18 deletions.
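For context, a minimal sketch (not part of the commit) of the Keras-style polynomial decay rule the fix implements; the function and parameter names here are illustrative:

def poly_decay_lr(step, max_iter, start_lr, end_lr, power):
    """Keras-style polynomial decay: interpolate from start_lr down to
    end_lr over max_iter steps, then hold at end_lr."""
    frac = min(step, max_iter) / max_iter
    scale = (1.0 - frac) ** power
    return (start_lr - end_lr) * scale + end_lr

This is the same expression the updated global_schedule and optimizer_schedule below compute from m_start_lr, m_end_lr, m_max_iter, and m_p.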
161 changes: 161 additions & 0 deletions bamboo/unit_tests/test_unit_callback_poly_learning_rate.py
@@ -0,0 +1,161 @@
"""Test to check polynomial decay learning rate schedule.
LBANN is run with the polynomial learning rate schedule and the log
files are post-processed to make sure that the correct learning rate
values are used.
"""
import os
import os.path
import random
import re
import sys

# Bamboo utilities
current_file = os.path.realpath(__file__)
current_dir = os.path.dirname(current_file)
sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python'))
import tools

# ==============================================
# Learning rate schedule parameters
# ==============================================

lr_power = 0.8
lr_num_epochs = 5
lr_start = 1
lr_end = 0.1

# ==============================================
# Setup LBANN experiment
# ==============================================

def setup_experiment(lbann):
    """Construct LBANN experiment.
    Args:
        lbann (module): Module for LBANN Python frontend
    """
    mini_batch_size = 1
    trainer = lbann.Trainer(mini_batch_size)
    model = construct_model(lbann)
    data_reader = construct_data_reader(lbann)
    optimizer = lbann.SGD(learn_rate=lr_start)
    return trainer, model, data_reader, optimizer

def construct_model(lbann):
    """Construct LBANN model.
    Args:
        lbann (module): Module for LBANN Python frontend
    """

    # Layer graph
    x = lbann.Input(data_field='samples')
    x = lbann.FullyConnected(x, num_neurons=1)

    # Model objects
    metrics = []
    callbacks = [
        lbann.CallbackPolyLearningRate(
            power=lr_power,
            num_epochs=lr_num_epochs,
            end_lr=lr_end,
        ),
    ]

    # Construct model
    return lbann.Model(lr_num_epochs+2,
                       layers=x,
                       metrics=metrics,
                       callbacks=callbacks)

def construct_data_reader(lbann):
    """Construct Protobuf message for a synthetic data reader.
    The synthetic data reader generates data on the fly, so this file
    does not need to provide sample access functions.
    Args:
        lbann (module): Module for LBANN Python frontend
    """
    message = lbann.reader_pb2.DataReader()
    _reader = message.reader.add()
    _reader.name = 'synthetic'
    _reader.role = 'train'
    _reader.num_samples = 2
    _reader.synth_dimensions = '1'
    _reader.percent_of_data_to_use = 1.0
    return message

# ==============================================
# Setup PyTest
# ==============================================

def augment_test_func(test_func):
    """Augment test function to parse log files.
    `tools.create_tests` creates functions that run an LBANN
    experiment. This function creates augmented functions that parse
    the log files after LBANN finishes running, e.g. to check metrics
    or runtimes.
    Note: The naive approach is to define the augmented test functions
    in a loop. However, Python closures are late binding. In other
    words, the function would be overwritten every time we define it.
    We get around this overwriting problem by defining the augmented
    function in the local scope of another function.
    Args:
        test_func (function): Test function created by
            `tools.create_tests`.
    Returns:
        function: Test that can interact with PyTest.
    """
    test_name = test_func.__name__

    # Define test function
    def func(cluster, dirname):

        # Run LBANN experiment
        experiment_output = test_func(cluster, dirname)

        # Parse LBANN log file
        lr_list = []
        log_file = experiment_output['stdout_log_file']
        with open(log_file) as f:
            for line in f:
                match = re.search(
                    'changing global learning rate to ([0-9.]+)',
                    line)
                if match:
                    lr_list.append(float(match.group(1)))

        # Make sure file has been parsed correctly
        assert len(lr_list) == lr_num_epochs, \
            f'Error parsing {log_file} ' \
            f'(expected {lr_num_epochs} learning rates, ' \
            f'but found {len(lr_list)})'

        # Make sure learning rates match expected values
        tol = 1e-5
        for epoch in range(lr_num_epochs):
            lr = lr_list[epoch]
            scale = (1 - (epoch+1)/lr_num_epochs) ** lr_power
            expected_lr = (lr_start - lr_end) * scale + lr_end
            assert expected_lr-tol < lr < expected_lr+tol, \
                f'Incorrect learning rate at epoch {epoch} ' \
                f'(expected {expected_lr}, but found {lr})'

    # Return test function from factory function
    func.__name__ = test_name
    return func

# Create test functions that can interact with PyTest
for _test_func in tools.create_tests(setup_experiment, __file__):
    globals()[_test_func.__name__] = augment_test_func(_test_func)
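The closure late-binding behavior that the docstring above works around can be reproduced in a standalone sketch (illustrative only, not part of the test file):

# Naive loop: every lambda closes over the same name `i`, which is 2 after the loop.
naive = [lambda: i for i in range(3)]
print([f() for f in naive])  # [2, 2, 2]

# Factory function: each call creates a new scope that binds its own `i`.
def make(i):
    def f():
        return i
    return f

fixed = [make(i) for i in range(3)]
print([f() for f in fixed])  # [0, 1, 2]

The test uses the same factory pattern: augment_test_func returns func from its own local scope, so each generated test keeps its own test_func.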
8 changes: 3 additions & 5 deletions include/lbann/callbacks/learning_rate.hpp
@@ -295,12 +295,10 @@ class poly_learning_rate : public learning_rate {
   size_t m_num_epochs;
   /// The maximum number of iterations until which the learning rate changes
   size_t m_max_iter;
-  /// The minimum learning rate
+  /// The initial learning rate
+  float m_start_lr;
+  /// The final learning rate
   float m_end_lr;
-  /// The current rate to scale the base learning rate
-  float m_lr;
-  /// The learning rate scale used at the end of the last epoch
-  float m_last_epoch_lr;
 };
 
 // Builder function
27 changes: 14 additions & 13 deletions src/callbacks/learning_rate.cpp
@@ -251,22 +251,23 @@ poly_learning_rate::poly_learning_rate(
   double p, size_t n_epochs, size_t max_iter)
   : learning_rate(std::vector<std::string>()),
     m_p(p), m_num_epochs(n_epochs), m_max_iter(max_iter),
-    m_end_lr(0.0f),
-    m_lr(1.0f), m_last_epoch_lr(1.0f) {}
+    m_start_lr(0.0f), m_end_lr(0.0f)
+{}
 
 poly_learning_rate::poly_learning_rate(
   double p, size_t n_epochs, size_t max_iter, double end_lr, std::vector<std::string> weights_names)
   : learning_rate(std::move(weights_names)),
     m_p(p), m_num_epochs(n_epochs), m_max_iter(max_iter),
-    m_end_lr(end_lr),
-    m_lr(1.0f), m_last_epoch_lr(1.0f) {}
+    m_start_lr(0.0f), m_end_lr(end_lr)
+{}
 
 /**
  * Check if the maximum number of iterations is set. If not, compute it by the
  * number of epochs and the number of iterations per epoch.
  */
 void poly_learning_rate::setup(model *m) {
   learning_rate::setup(m);
+  m_start_lr = get_current_global_learning_rate();
   if (m_max_iter == 0ull) {
     data_coordinator& dc = get_trainer().get_data_coordinator();
     m_max_iter = m_num_epochs * dc.get_num_iterations_per_epoch(execution_mode::training);
@@ -277,22 +277,22 @@ void poly_learning_rate::setup(model *m) {
  * Keep the record of the learning rate at the end of the current epoch.
  */
 float poly_learning_rate::global_schedule(model *m) {
-  const float scale = m_lr / m_last_epoch_lr;
-  m_last_epoch_lr = m_lr;
-  return (poly_learning_rate::get_current_global_learning_rate() - m_end_lr) * scale + m_end_lr;
+  const auto& c = static_cast<const SGDExecutionContext&>(m->get_execution_context());
+  const size_t iter = std::min(c.get_step(), m_max_iter);
+  const float scale = static_cast<float>(
+    std::pow(static_cast<double>(m_max_iter-iter)/m_max_iter, m_p));
+  return (m_start_lr - m_end_lr) * scale + m_end_lr;
 }
 
 /**
  * Compute the learning rate for the next iteration.
  */
 float poly_learning_rate::optimizer_schedule(model *m, optimizer &opt) {
   const auto& c = static_cast<const SGDExecutionContext&>(m->get_execution_context());
-  const size_t cur_iter = c.get_step();
-  if (m_max_iter > cur_iter) {
-    m_lr = static_cast<float>(std::pow(static_cast<double>(m_max_iter - cur_iter)/m_max_iter, m_p));
-  }
-  const float scale = m_lr / m_last_epoch_lr;
-  return (poly_learning_rate::get_current_global_learning_rate() - m_end_lr) * scale + m_end_lr;
+  const size_t iter = std::min(c.get_step(), m_max_iter);
+  const float scale = static_cast<float>(
+    std::pow(static_cast<double>(m_max_iter-iter)/m_max_iter, m_p));
+  return (m_start_lr - m_end_lr) * scale + m_end_lr;
 }
 
 optimizerwise_adaptive_learning_rate::
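As a sanity check, the per-epoch learning rates the test above expects for its parameters (power 0.8, 5 epochs, start 1.0, end 0.1) can be previewed with a short sketch that mirrors the decay formula rather than calling LBANN:

power, num_epochs, start_lr, end_lr = 0.8, 5, 1.0, 0.1
for epoch in range(num_epochs):
    scale = (1 - (epoch + 1) / num_epochs) ** power
    lr = (start_lr - end_lr) * scale + end_lr
    print(f'epoch {epoch}: lr = {lr:.4f}')
# Approximate output: 0.8529, 0.6981, 0.5324, 0.3484, 0.1000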
