[EM] Avoid stream sync in quantile sketching. (#10765)
trivialfis committed Aug 30, 2024
1 parent 61dd854 commit 34d4ab4
Showing 12 changed files with 313 additions and 313 deletions.
16 changes: 15 additions & 1 deletion src/common/algorithm.cuh
@@ -1,5 +1,5 @@
/**
* Copyright 2022-2023 by XGBoost Contributors
* Copyright 2022-2024, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
#define XGBOOST_COMMON_ALGORITHM_CUH_
@@ -258,5 +258,19 @@ void ArgSort(xgboost::Context const *ctx, xgboost::common::Span<U> keys,
sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice,
cuctx->Stream()));
}

template <typename InIt, typename OutIt, typename Predicate>
void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_first,
Predicate pred) {
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
// See thrust issue #1302, XGBoost #6822
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
size_t length = std::distance(in_first, in_second);
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
auto begin_input = in_first + offset;
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
out_first = thrust::copy_if(cuctx->CTP(), begin_input, end_input, out_first, pred);
}
}
} // namespace xgboost::common
#endif // XGBOOST_COMMON_ALGORITHM_CUH_
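
A note on the new helper: thrust::copy_if indexes with a 32-bit type internally (Thrust issue #1302, XGBoost #6822), so CopyIf walks the input in chunks below 2^31 elements, and every chunk now runs on the context's stream via cuctx->CTP() instead of an ad-hoc caching-allocator policy tied to the legacy default stream. A minimal standalone sketch of the same pattern, assuming only plain Thrust and a caller-provided stream (BatchedCopyIf is a hypothetical name, not part of XGBoost):

#include <cuda_runtime.h>
#include <thrust/copy.h>
#include <thrust/execution_policy.h>

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <limits>

// Chunked copy_if: stay below 2^31 elements per call so Thrust's 32-bit
// internal indexing never overflows, and submit every chunk to the same
// caller-provided stream.
template <typename InIt, typename OutIt, typename Predicate>
OutIt BatchedCopyIf(cudaStream_t stream, InIt first, InIt last, OutIt out, Predicate pred) {
  std::size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
  std::size_t const length = std::distance(first, last);
  for (std::size_t offset = 0; offset < length; offset += kMaxCopySize) {
    auto begin = first + offset;
    auto end = first + std::min(offset + kMaxCopySize, length);
    // copy_if still blocks on `stream` to learn how many elements it kept,
    // but no other stream is synchronized.
    out = thrust::copy_if(thrust::cuda::par.on(stream), begin, end, out, pred);
  }
  return out;
}

As in the in-tree CopyIf, the returned iterator marks the end of the compacted output after all chunks.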
36 changes: 5 additions & 31 deletions src/common/device_helpers.cuh
@@ -637,12 +637,11 @@ struct SegmentedUniqueReduceOp
* \return Number of unique values in total.
*/
template <typename DerivedPolicy, typename KeyInIt, typename KeyOutIt, typename ValInIt,
typename ValOutIt, typename CompValue, typename CompKey>
size_t
SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
CompValue comp, CompKey comp_key=thrust::equal_to<size_t>{}) {
typename ValOutIt, typename CompValue, typename CompKey = thrust::equal_to<size_t>>
size_t SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
CompValue comp, CompKey comp_key = thrust::equal_to<size_t>{}) {
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
@@ -676,16 +675,6 @@ SegmentedUnique(const thrust::detail::execution_policy_base<DerivedPolicy> &exec
return n_uniques;
}

template <typename... Inputs,
std::enable_if_t<std::tuple_size<std::tuple<Inputs...>>::value == 7>
* = nullptr>
size_t SegmentedUnique(Inputs &&...inputs) {
dh::XGBCachingDeviceAllocator<char> alloc;
return SegmentedUnique(thrust::cuda::par(alloc),
std::forward<Inputs &&>(inputs)...,
thrust::equal_to<size_t>{});
}

/**
* \brief Unique by key for many groups of data. Has the same constraint as `SegmentedUnique`.
*
@@ -793,21 +782,6 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
#endif
}

template <typename InIt, typename OutIt, typename Predicate>
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
// We loop over batches because thrust::copy_if can't deal with sizes > 2^31
// See thrust issue #1302, XGBoost #6822
size_t constexpr kMaxCopySize = std::numeric_limits<int>::max() / 2;
size_t length = std::distance(in_first, in_second);
XGBCachingDeviceAllocator<char> alloc;
for (size_t offset = 0; offset < length; offset += kMaxCopySize) {
auto begin_input = in_first + offset;
auto end_input = in_first + std::min(offset + kMaxCopySize, length);
out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
end_input, out_first, pred);
}
}

template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
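
InclusiveSum above is a thin wrapper over InclusiveScan with cub::Sum(). The two-phase CUB pattern such wrappers hide (query the temporary-storage size, allocate, then run the scan) looks roughly like this sketch; error handling is omitted, and cudaMallocAsync needs CUDA 11.2 or later:

#include <cub/cub.cuh>
#include <cuda_runtime.h>

#include <cstddef>

void InclusiveSumSketch(float const *d_in, float *d_out, int num_items, cudaStream_t stream) {
  void *d_temp = nullptr;
  std::size_t temp_bytes = 0;
  // First call: d_temp is null, so CUB only reports the bytes it needs.
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, num_items, stream);
  cudaMallocAsync(&d_temp, temp_bytes, stream);  // stream-ordered allocation
  // Second call performs the scan on the caller's stream.
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, num_items, stream);
  cudaFreeAsync(d_temp, stream);
}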
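
The deleted variadic overload of SegmentedUnique existed only to inject thrust::equal_to<size_t> when the caller omitted the key comparator; a defaulted template parameter plus a defaulted function argument now covers both call styles in one signature. A toy illustration of the idiom, independent of XGBoost (CountAdjacentEqual is an invented example):

#include <cstddef>
#include <functional>

// Comp needs its own template default because it cannot be deduced when the
// caller omits the final argument; the defaulted function argument then
// supplies the runtime value.
template <typename T, typename Comp = std::equal_to<T>>
std::size_t CountAdjacentEqual(T const *first, T const *last, Comp comp = Comp{}) {
  if (first == last) {
    return 0;
  }
  std::size_t n = 0;
  for (auto it = first; it + 1 != last; ++it) {
    n += comp(*it, *(it + 1)) ? 1 : 0;
  }
  return n;
}

// Both call styles compile against the single signature:
//   int a[] = {1, 1, 2};
//   CountAdjacentEqual(a, a + 3);                                    // default comparator
//   CountAdjacentEqual(a, a + 3, [](int l, int r) { return l == r; });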
73 changes: 38 additions & 35 deletions src/common/hist_util.cu
@@ -106,26 +106,27 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_idx_t num_ro
return std::min(sketch_batch_num_elements, kIntMax);
}

void SortByWeight(dh::device_vector<float>* weights, dh::device_vector<Entry>* sorted_entries) {
void SortByWeight(Context const* ctx, dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries) {
// Sort both entries and weights.
dh::XGBDeviceAllocator<char> alloc;
auto cuctx = ctx->CUDACtx();
CHECK_EQ(weights->size(), sorted_entries->size());
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), sorted_entries->end(),
weights->begin(), detail::EntryCompareOp());
thrust::sort_by_key(cuctx->TP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
detail::EntryCompareOp());

// Scan weights
dh::XGBCachingDeviceAllocator<char> caching;
thrust::inclusive_scan_by_key(
thrust::cuda::par(caching), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
cuctx->CTP(), sorted_entries->begin(), sorted_entries->end(), weights->begin(),
weights->begin(),
[=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; });
}

void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst_idx_t> d_cuts_ptr,
void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
Span<bst_idx_t> d_cuts_ptr,
dh::device_vector<Entry>* p_sorted_entries,
dh::device_vector<float>* p_sorted_weights,
dh::caching_device_vector<size_t>* p_column_sizes_scan) {
info.feature_types.SetDevice(device);
info.feature_types.SetDevice(ctx->Device());
auto d_feature_types = info.feature_types.ConstDeviceSpan();
CHECK(!d_feature_types.empty());
auto& column_sizes_scan = *p_column_sizes_scan;
@@ -142,30 +143,32 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
auto d_sorted_weights = dh::ToSpan(*p_sorted_weights);
auto val_in_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
auto val_out_it = thrust::make_zip_iterator(d_sorted_entries.data(), d_sorted_weights.data());
n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
val_in_it, val_in_it + sorted_entries.size(), new_column_scan.data().get(), val_out_it,
[=] __device__(Pair const& l, Pair const& r) {
Entry const& le = thrust::get<0>(l);
Entry const& re = thrust::get<0>(r);
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
return le.fvalue == re.fvalue;
}
return false;
});
n_uniques =
dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(), val_in_it,
val_in_it + sorted_entries.size(), new_column_scan.data().get(),
val_out_it, [=] __device__(Pair const& l, Pair const& r) {
Entry const& le = thrust::get<0>(l);
Entry const& re = thrust::get<0>(r);
if (le.index == re.index && IsCat(d_feature_types, le.index)) {
return le.fvalue == re.fvalue;
}
return false;
});
p_sorted_weights->resize(n_uniques);
} else {
n_uniques = dh::SegmentedUnique(
column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(), new_column_scan.data().get(),
sorted_entries.begin(), [=] __device__(Entry const& l, Entry const& r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
n_uniques = dh::SegmentedUnique(ctx->CUDACtx()->CTP(), column_sizes_scan.data().get(),
column_sizes_scan.data().get() + column_sizes_scan.size(),
sorted_entries.begin(), sorted_entries.end(),
new_column_scan.data().get(), sorted_entries.begin(),
[=] __device__(Entry const& l, Entry const& r) {
if (l.index == r.index) {
if (IsCat(d_feature_types, l.index)) {
return l.fvalue == r.fvalue;
}
}
return false;
});
}
sorted_entries.resize(n_uniques);

@@ -189,7 +192,7 @@ void RemoveDuplicatedCategories(DeviceOrd device, MetaInfo const& info, Span<bst
}
});
// Turn size into ptr.
thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), new_cuts_size.cend(),
thrust::exclusive_scan(ctx->CUDACtx()->CTP(), new_cuts_size.cbegin(), new_cuts_size.cend(),
d_cuts_ptr.data());
}
} // namespace detail
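
RemoveDuplicatedCategories now routes every Thrust call through the context's stream-ordered policy (ctx->CUDACtx()->CTP()) instead of implicit default-stream policies. Its two building blocks, collapsing adjacent duplicates of categorical values through a zip iterator and turning per-column sizes into offsets with an exclusive scan, can be sketched standalone. The sketch below assumes every feature is categorical and drops the per-column segmentation that SegmentedUnique provides; Entry, SameCategoricalValue, and DedupAndBuildPtrs are illustrative names, not XGBoost's API:

#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#include <thrust/distance.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/scan.h>
#include <thrust/tuple.h>
#include <thrust/unique.h>

#include <cstddef>

struct Entry {
  int index;     // feature index
  float fvalue;  // feature value
};

struct SameCategoricalValue {
  // Adjacent (entry, weight) pairs collapse only when they belong to the same
  // feature and carry the same value. The real predicate also consults a span
  // of feature types so numerical features are never collapsed.
  __device__ bool operator()(thrust::tuple<Entry, float> const &l,
                             thrust::tuple<Entry, float> const &r) const {
    Entry const &le = thrust::get<0>(l);
    Entry const &re = thrust::get<0>(r);
    return le.index == re.index && le.fvalue == re.fvalue;
  }
};

void DedupAndBuildPtrs(cudaStream_t stream, thrust::device_vector<Entry> *entries,
                       thrust::device_vector<float> *weights,
                       thrust::device_vector<std::size_t> const &sizes,
                       thrust::device_vector<std::size_t> *ptrs) {
  auto policy = thrust::cuda::par.on(stream);
  auto first = thrust::make_zip_iterator(thrust::make_tuple(entries->begin(), weights->begin()));
  auto last = thrust::make_zip_iterator(thrust::make_tuple(entries->end(), weights->end()));
  // unique keeps entries and weights in lockstep because both sit behind the
  // same zip iterator.
  auto new_end = thrust::unique(policy, first, last, SameCategoricalValue{});
  std::size_t n_uniques = thrust::distance(first, new_end);
  entries->resize(n_uniques);
  weights->resize(n_uniques);
  // Turn per-column sizes into offsets, mirroring the "size into ptr" scan
  // above: ptrs[i] becomes the sum of sizes[0..i).
  ptrs->resize(sizes.size());
  thrust::exclusive_scan(policy, sizes.cbegin(), sizes.cend(), ptrs->begin());
}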
@@ -225,7 +228,7 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
std::size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weight[idx] = sample_weight[ridx + base_rowid];
});
detail::SortByWeight(&entry_weight, &sorted_entries);
detail::SortByWeight(ctx, &entry_weight, &sorted_entries);
} else {
thrust::sort(cuctx->TP(), sorted_entries.begin(), sorted_entries.end(),
detail::EntryCompareOp());
@@ -238,21 +241,21 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
sorted_entries.data().get(), [] __device__(Entry const& e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(ctx->Device(), info.num_col_, num_cuts_per_feature,
detail::GetColumnSizesScan(ctx->CUDACtx(), ctx->Device(), info.num_col_, num_cuts_per_feature,
IterSpan{batch_it, sorted_entries.size()}, dummy_is_valid, &cuts_ptr,
&column_sizes_scan);
auto d_cuts_ptr = cuts_ptr.DeviceSpan();
if (sketch_container->HasCategorical()) {
auto p_weight = entry_weight.empty() ? nullptr : &entry_weight;
detail::RemoveDuplicatedCategories(ctx->Device(), info, d_cuts_ptr, &sorted_entries, p_weight,
detail::RemoveDuplicatedCategories(ctx, info, d_cuts_ptr, &sorted_entries, p_weight,
&column_sizes_scan);
}

auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());

// Add cuts into sketches
sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
sketch_container->Push(ctx, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), d_cuts_ptr,
h_cuts_ptr.back(), dh::ToSpan(entry_weight));

sorted_entries.clear();
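
The weight-gathering lambda above maps each flat element index back to its row through the CSR row pointers (dh::SegmentId) before sorting. Conceptually that lookup is an upper-bound binary search; a device-friendly sketch (this SegmentId is a reimplementation for illustration, not the dh:: helper):

#include <cstddef>

// Given segment pointers ptrs[0..n_ptrs) in CSR style, return the segment
// that owns flat index idx: the largest i with ptrs[i] <= idx.
__host__ __device__ inline std::size_t SegmentId(std::size_t const *ptrs, std::size_t n_ptrs,
                                                 std::size_t idx) {
  std::size_t lo = 0, hi = n_ptrs;
  while (lo < hi) {  // upper-bound search over ptrs
    std::size_t mid = lo + (hi - lo) / 2;
    if (ptrs[mid] <= idx) {
      lo = mid + 1;
    } else {
      hi = mid;
    }
  }
  return lo - 1;  // ptrs[lo - 1] <= idx < ptrs[lo]
}

// Per-element usage inside the gather, mirroring the code above:
//   std::size_t ridx = SegmentId(row_ptrs, n_rows + 1, element_idx);
//   d_temp_weight[idx] = sample_weight[ridx + base_rowid];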
(The remaining 9 changed files are not shown here.)
