Move state from data reader to data coordinator (LLNL#1744)
* Mini-batch size is now passed into the data reader fetch functions
  from the data coordinator, rather than having the data reader
  calculate the current mini-batch size itself (see the sketch after this list).

* Removed jag_partitioned field since it was deprecated by sample lists

* Removed the set_master and is_master functions and the m_master field, since that information is available from get_comm()->am_world_master()

* Removed the cached rank value (m_rank_in_model), which really should be the rank in the trainer but was labeled as the rank in the model

* Switched data readers over to using the global get_trainer() function call to find the trainer

* Updated data readers to use get_rank_in_trainer rather than a local field

* Fixed typo

* Updated tests to pass the number of samples
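
A minimal C++ sketch of the resulting call pattern, assembled from the buffered_data_coordinator.cpp hunk below; the variable names (dr, buf, local_input_buffers) follow that hunk, and the surrounding setup is assumed rather than taken verbatim from the commit:

    // The data coordinator computes the current mini-batch size and passes it
    // to the reader's fetch call (new third argument), instead of the reader
    // deriving it internally.
    int loaded_batch_size = dr->get_loaded_mini_batch_size();
    const int end_pos =
      std::min(static_cast<size_t>(dr->m_current_pos + loaded_batch_size),
               dr->m_shuffled_indices.size());
    const int mb_size = std::min(
      El::Int{((end_pos - dr->m_current_pos) + dr->m_sample_stride - 1) / dr->m_sample_stride},
      local_input_buffers[INPUT_DATA_TYPE_SAMPLES]->Width());
    buf.m_num_samples_fetched =
      dr->fetch(local_input_buffers, buf.m_indices_fetched_per_mb, mb_size);

    // Inside the data readers, the removed cached state is replaced by lookups
    // on the comm object and the global trainer accessor:
    if (get_comm()->am_world_master()) { /* was: if (is_master()) */ }
    int rank = get_comm()->get_rank_in_trainer();  // was: get_rank() on the reader
    auto& current_trainer = get_trainer();         // was: *m_trainer cached on the reader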
bvanessen committed Oct 14, 2021
1 parent 10e00a3 commit 2458baa
Showing 23 changed files with 98 additions and 234 deletions.
14 changes: 0 additions & 14 deletions include/lbann/data_readers/compound_data_reader.hpp
@@ -92,20 +92,6 @@ class generic_compound_data_reader : public generic_data_reader {
}
}

void set_master(bool m) override {
generic_data_reader::set_master(m);
for (auto&& reader : m_data_readers) {
reader->set_master(m);
}
}

void set_rank(int rank) override {
generic_data_reader::set_rank(rank);
for (auto&& reader : m_data_readers) {
reader->set_rank(rank);
}
}

/// needed to support data_store_merge_samples
std::vector<generic_data_reader*> & get_data_readers() {
return m_data_readers;
67 changes: 7 additions & 60 deletions include/lbann/data_readers/data_reader.hpp
@@ -70,9 +70,6 @@ class generic_data_reader {
public:
using unused_index_map_t = std::map<execution_mode,std::vector<int>>;

#define JAG_NOOP_VOID if (m_jag_partitioned) { return; }
#define JAG_NOOP_INT if (m_jag_partitioned) { return 0; }

/**
* ctor
*/
@@ -97,7 +94,6 @@
m_global_last_mini_batch_size(0),
m_world_master_mini_batch_adjustment(0),
m_num_parallel_readers(0),
m_rank_in_model(0),
m_max_files_to_load(0),
m_file_dir(""),
m_data_sample_list(""),
@@ -106,14 +102,10 @@
m_shuffle(shuffle),
m_absolute_sample_count(0),
m_use_percent(1.0),
m_master(false),
m_gan_labelling(false), // default, not GAN
m_gan_label_value(
0), // If GAN, default for fake label, discriminator model
m_gan_labelling(false), //default, not GAN
m_gan_label_value(0), //If GAN, default for fake label, discriminator model
m_io_thread_pool(nullptr),
m_jag_partitioned(false),
m_keep_sample_order(false),
m_trainer(nullptr),
m_issue_warning(true)
{
// By default only support fetching input samples
@@ -131,7 +123,6 @@
/// set the comm object
void set_comm(lbann_comm *comm) {
m_comm = comm;
set_master(comm->am_world_master());
}

/// returns a (possibly nullptr) to comm
Expand Down Expand Up @@ -306,7 +297,7 @@ class generic_data_reader {

/** @brief Fetch a mini-batch worth of data, including samples, labels, responses (as appropriate) */
int fetch(std::map<data_field_type, CPUMat*>& input_buffers,
El::Matrix<El::Int>& indices_fetched);
El::Matrix<El::Int>& indices_fetched, size_t mb_size);

/** @brief Check to see if the data reader supports this specific data field
*/
@@ -448,7 +439,6 @@
}
/// Set the mini batch size across all models (global)
void set_global_mini_batch_size(const int s) {
JAG_NOOP_VOID
m_global_mini_batch_size = s;
}
/// Return the mini_batch_size across all models (global)
@@ -457,7 +447,6 @@
}
/// Set the mini batch stride
void set_stride_to_next_mini_batch(const int s) {
JAG_NOOP_VOID
m_stride_to_next_mini_batch = s;
}
/// Return the mini batch stride.
@@ -466,7 +455,6 @@
}
/// Set the sample stride
void set_sample_stride(const int s) {
JAG_NOOP_VOID
m_sample_stride = s;
}
/// Return the sample stride.
@@ -483,7 +471,6 @@
}
/// Return the base offset.
virtual void set_base_offset(const int s) {
JAG_NOOP_VOID
m_base_offset = s;
}
/// Return the base offset.
@@ -492,7 +479,6 @@
}
/// Set the model offset
void set_model_offset(const int s) {
JAG_NOOP_VOID
m_model_offset = s;
}
/// Return the model offset.
@@ -501,7 +487,6 @@
}
/// Set the last mini batch size
void set_last_mini_batch_size(const int s) {
JAG_NOOP_VOID
m_last_mini_batch_size = s;
}
/// Return the last mini batch size
@@ -510,7 +495,6 @@
}
/// Set the last mini batch size across all models (global)
void set_global_last_mini_batch_size(const int s) {
JAG_NOOP_VOID
m_global_last_mini_batch_size = s;
}
/// Return the last mini batch size across all models (global)
@@ -519,7 +503,6 @@
}
/// Set the world master mini batch adjustment (global)
void set_world_master_mini_batch_adjustment(const int s) {
JAG_NOOP_VOID
m_world_master_mini_batch_adjustment = s;
}
/// Return the world master mini batch adjustment (global)
@@ -528,7 +511,6 @@
}
/// Set the last mini batch stride
void set_stride_to_last_mini_batch(const int s) {
JAG_NOOP_VOID
m_stride_to_last_mini_batch = s;
}
/// Return the last mini batch stride
@@ -616,26 +598,6 @@
return m_current_mini_batch_idx;
}

/// only the master may write to cerr or cout; primarily for use in debugging during development
virtual void set_master(bool m) {
m_master = m;
}

/// only the master may write to cerr or cout; primarily for use in debugging during development
bool is_master() const {
return m_master;
}

/// Allow the reader to know where it is in the model hierarchy
virtual void set_rank(int rank) {
m_rank_in_model = rank;
}

/// Allow the reader to know where it is in the model hierarchy
int get_rank() const {
return m_rank_in_model;
}

/**
* Optionally resizes the shuffled indices based on the data reader
* prototext settings: absolute_sample_count, percent_of_data_to_use.
@@ -714,12 +676,9 @@

virtual bool priming_data_store() const;

void set_trainer(trainer *t) { m_trainer = t; }

trainer& get_trainer() const {
if(m_trainer == nullptr) { LBANN_ERROR("get_trainer called with nullptr"); }
return *m_trainer;
}
/// experimental; used to ensure all readers for jag_conduit_hdf5
/// have identical shuffled indices
virtual void post_update() {}

/** Set the transform pipeline this data reader will use. */
void set_transform_pipeline(transform::transform_pipeline&& tp) {
@@ -836,6 +795,7 @@
/// Shuffle indices and profide a random number generator
virtual void shuffle_indices(rng_gen& gen);

public:
int m_mini_batch_size;
int m_current_pos;
/// Batch Stride is typically batch_size, but may be a multiple of batch size if there are multiple readers
@@ -873,7 +833,6 @@

int m_num_parallel_readers; /// How many parallel readers are being used

int m_rank_in_model; /// What is the rank of the data reader within a given model
size_t m_max_files_to_load;
std::string m_file_dir;
std::string m_local_file_dir;
@@ -887,8 +846,6 @@
int m_first_n;
std::string m_role;

bool m_master;

/** @brief Print the return values from various get_X methods to file
*
* For use in unit testing. Only the master prints.
@@ -928,20 +885,10 @@

observer_ptr<thread_pool> m_io_thread_pool;

/// special handling for 1B jag; each reader
/// owns a unique subset of the data
bool m_jag_partitioned;

/** Whether to keep the order of loaded samples same as it is in the
* file to make testing and validation easier */
bool m_keep_sample_order;

/// called by fetch_data a single time if m_jag_partitioned = true;
/// this sets various member variables (num_iterations, m_reset_mini_batch_index,
/// etc.
void set_jag_variables(int mb_size);
trainer *m_trainer;

/** Transform pipeline for preprocessing data. */
transform::transform_pipeline m_transform_pipeline;

2 changes: 0 additions & 2 deletions include/lbann/data_readers/data_reader_jag_conduit.hpp
@@ -171,8 +171,6 @@ class data_reader_jag_conduit : public generic_data_reader {
bool has_list_per_model() const override { return m_list_per_model; }
bool has_list_per_trainer() const override { return m_list_per_trainer; }



/// Return the number of measurement views
unsigned int get_num_img_srcs() const;
/// Return the linearized size of an image
8 changes: 4 additions & 4 deletions include/lbann/data_readers/data_reader_sample_list_impl.hpp
@@ -85,7 +85,7 @@ void data_reader_sample_list<SampleListT>::shuffle_indices(rng_gen& gen)
template <typename SampleListT>
void data_reader_sample_list<SampleListT>::load()
{
if (is_master()) {
if (get_comm()->am_world_master()) {
std::cout << "starting data_reader_sample_list::load()\n";
}
const std::string sample_list_file = get_data_sample_list();
@@ -129,7 +129,7 @@ void data_reader_sample_list<SampleListT>::load_list_of_samples(
else {
m_sample_list.load(sample_list_file, *(this->m_comm), true);
}
if (is_master()) {
if (get_comm()->am_world_master()) {
std::cout << "Time to load sample list '" << sample_list_file
<< "': " << get_time() - tm1 << std::endl;
}
@@ -138,7 +138,7 @@ void data_reader_sample_list<SampleListT>::load_list_of_samples(
double tm3 = get_time();
m_sample_list.all_gather_packed_lists(*m_comm);

if (is_master()) {
if (get_comm()->am_world_master()) {
std::cout << "Time to gather sample list '" << sample_list_file
<< "': " << get_time() - tm3 << std::endl;
}
@@ -160,7 +160,7 @@ void data_reader_sample_list<SampleListT>::load_list_of_samples_from_archive(
iarchive(m_sample_list); // Read the data from the archive
double tm2 = get_time();

if (is_master()) {
if (get_comm()->am_world_master()) {
std::cout << "Time to load sample list from archive: " << tm2 - tm1
<< std::endl;
}
2 changes: 1 addition & 1 deletion src/callbacks/debug_io.cpp
@@ -104,7 +104,7 @@ void debug_io::print_phase_start(model *m, execution_mode mode) {
generic_data_reader* data_reader = dc.get_data_reader(mode);
const auto& step = c.get_step();

if(data_reader->get_rank() < data_reader->get_num_parallel_readers()) {
if(m->get_comm()->get_rank_in_trainer() < data_reader->get_num_parallel_readers()) {
std::cout << "[" << m->get_comm()->get_trainer_rank()
<< "." << m->get_comm()->get_rank_in_trainer()
<< "] @" << 0 << "." << step
10 changes: 9 additions & 1 deletion src/data_coordinator/buffered_data_coordinator.cpp
@@ -143,8 +143,16 @@ int buffered_data_coordinator<TensorDataType>::fetch_to_local_matrix(data_buffer
for(auto& b : buf.m_input_buffers) {
local_input_buffers[b.first] = static_cast<CPUMat*>(&(b.second->Matrix()));
}

// Compute the size of the current mini-batch

int loaded_batch_size = dr->get_loaded_mini_batch_size();
const int end_pos = std::min(static_cast<size_t>(dr->m_current_pos+loaded_batch_size), dr->m_shuffled_indices.size());
const int mb_size = std::min(El::Int{((end_pos - dr->m_current_pos) + dr->m_sample_stride - 1) / dr->m_sample_stride},
local_input_buffers[INPUT_DATA_TYPE_SAMPLES]->Width());

/** @brief Each rank will fetch a mini-batch worth of data into it's buffer */
buf.m_num_samples_fetched = dr->fetch(local_input_buffers, buf.m_indices_fetched_per_mb);
buf.m_num_samples_fetched = dr->fetch(local_input_buffers, buf.m_indices_fetched_per_mb, mb_size);

bool data_valid = (buf.m_num_samples_fetched > 0);
if(data_valid) {
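
To make the mb_size computation above concrete, here is a worked example with hypothetical numbers (not taken from the commit):

    // Suppose m_current_pos = 990, loaded_batch_size = 64,
    // m_shuffled_indices.size() = 1000, and m_sample_stride = 2.
    // end_pos = min(990 + 64, 1000)             = 1000
    // mb_size = min(ceil((1000 - 990) / 2), W)  = min(5, W)
    // i.e. the final, partial mini-batch is clamped both to the remaining
    // shuffled indices and to the width W of the sample buffer.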
1 change: 0 additions & 1 deletion src/data_coordinator/data_coordinator.cpp
@@ -60,7 +60,6 @@ void data_coordinator::setup(thread_pool& io_thread_pool, int max_mini_batch_siz
if (!dr.second) continue;
dr.second->setup(m_io_thread_pool->get_num_threads(),
m_io_thread_pool);
dr.second->set_rank(m_comm->get_rank_in_trainer());
}

/** Calculate how many iterations are required for training, testing,
Expand Down
Loading

0 comments on commit 2458baa

Please sign in to comment.