Skip to content

Commit b4cc350

Browse files
Authored Jun 17, 2024
Fix categorical data with external memory. (#10433)
1 parent a8ddbac commit b4cc350

File tree

5 files changed

+31
-7
lines changed

5 files changed

+31
-7
lines changed
 

‎demo/guide-python/external_memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def make_batches(
4343
class Iterator(xgboost.DataIter):
4444
"""A custom iterator for loading files in batches."""
4545

46-
def __init__(self, file_paths: List[Tuple[str, str]]):
46+
def __init__(self, file_paths: List[Tuple[str, str]]) -> None:
4747
self._file_paths = file_paths
4848
self._it = 0
4949
# XGBoost will generate some cache files under current directory with the prefix

‎src/common/hist_util.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2017-2024 by XGBoost Contributors
2+
* Copyright 2017-2024, XGBoost Contributors
33
* \file hist_util.h
44
* \brief Utility for fast histogram aggregation
55
* \author Philip Cho, Tianqi Chen
@@ -11,7 +11,6 @@
1111
#include <cstdint> // for uint32_t
1212
#include <limits>
1313
#include <map>
14-
#include <memory>
1514
#include <utility>
1615
#include <vector>
1716

‎src/data/gradient_index.cc

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
*/
55
#include "gradient_index.h"
66

7-
#include <algorithm>
87
#include <limits>
98
#include <memory>
109
#include <utility> // for forward
@@ -126,8 +125,8 @@ INSTANTIATION_PUSH(data::ColumnarAdapterBatch)
126125
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
127126
auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
128127
// Must resize instead of allocating a new one. This function is called everytime a
129-
// new batch is pushed, and we grow the size accordingly without loosing the data the
130-
// previous batches.
128+
// new batch is pushed, and we grow the size accordingly without losing the data in
129+
// the previous batches.
131130
using T = decltype(t);
132131
std::size_t n_bytes = sizeof(T) * n_index;
133132
CHECK_GE(n_bytes, this->data.size());

‎src/data/histogram_cut_format.h

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2021-2023, XGBoost contributors
2+
* Copyright 2021-2024, XGBoost contributors
33
*/
44
#ifndef XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
55
#define XGBOOST_DATA_HISTOGRAM_CUT_FORMAT_H_
@@ -23,6 +23,15 @@ inline bool ReadHistogramCuts(common::HistogramCuts *cuts, common::AlignedResour
2323
if (!common::ReadVec(fi, &cuts->min_vals_.HostVector())) {
2424
return false;
2525
}
26+
bool has_cat{false};
27+
if (!fi->Read(&has_cat)) {
28+
return false;
29+
}
30+
decltype(cuts->MaxCategory()) max_cat{0};
31+
if (!fi->Read(&max_cat)) {
32+
return false;
33+
}
34+
cuts->SetCategorical(has_cat, max_cat);
2635
return true;
2736
}
2837

@@ -32,6 +41,8 @@ inline std::size_t WriteHistogramCuts(common::HistogramCuts const &cuts,
3241
bytes += common::WriteVec(fo, cuts.Values());
3342
bytes += common::WriteVec(fo, cuts.Ptrs());
3443
bytes += common::WriteVec(fo, cuts.MinValues());
44+
bytes += fo->Write(cuts.HasCategorical());
45+
bytes += fo->Write(cuts.MaxCategory());
3546
return bytes;
3647
}
3748
} // namespace xgboost::data

‎tests/python/test_data_iterator.py

+15
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,21 @@ def test_single_batch(tree_method: str = "approx") -> None:
5252
assert from_np.get_dump() == from_it.get_dump()
5353

5454

55+
def test_with_cat_single() -> None:
56+
X, y = tm.make_categorical(
57+
n_samples=128, n_features=3, n_categories=6, onehot=False
58+
)
59+
Xy = xgb.DMatrix(SingleBatch(data=X, label=y), enable_categorical=True)
60+
from_it = xgb.train({}, Xy, num_boost_round=3)
61+
62+
Xy = xgb.DMatrix(X, y, enable_categorical=True)
63+
from_Xy = xgb.train({}, Xy, num_boost_round=3)
64+
65+
jit = from_it.save_raw(raw_format="json")
66+
jxy = from_Xy.save_raw(raw_format="json")
67+
assert jit == jxy
68+
69+
5570
def run_data_iterator(
5671
n_samples_per_batch: int,
5772
n_features: int,

0 commit comments

Comments
 (0)
Please sign in to comment.