
Simplify DaskScikitLearnBase.predict #11411


Open · wants to merge 7 commits into master

Conversation

@TomAugspurger (Author)

This PR simplifies the DaskScikitLearnBase.predict method. The main change is to use map_partitions on the input rather than the current code, which hits dask/distributed#8998.

Opening this up early for feedback and to get it tested in CI. I've tested locally with dask==2024.9.1, 2024.12.1, and 2025.2.0, and everything in tests/test_distributed/test_with_dask/test_with_dask.py passes. Unfortunately, there are failures on dask@main, but those are different from the current set of failures.

Closes #10994
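
As a rough sketch of the direction (not the PR's exact code; _predict_part, the 1-D regression-style output, and the pandas-only meta are assumptions for illustration), the DataFrame path boils down to:

import numpy as np
import pandas as pd
import dask.dataframe as dd
import xgboost

def _predict_part(part: pd.DataFrame, booster: xgboost.Booster) -> pd.Series:
    # Runs on the worker holding this partition; in-place predict avoids
    # building a DMatrix copy of the data.
    predt = booster.inplace_predict(part)
    return pd.Series(predt, index=part.index)

def simple_predict(X: dd.DataFrame, booster: xgboost.Booster) -> dd.Series:
    # Hand graph construction to Dask instead of building tasks by hand.
    meta = pd.Series(np.empty(0, dtype=np.float32))
    return X.map_partitions(_predict_part, booster, meta=meta)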

@@ -1,4 +1,4 @@
 ## Update the following line to test changes to CI images
 ## See https://xgboost.readthedocs.io/en/latest/contrib/ci.html#making-changes-to-ci-containers

-IMAGE_TAG=main
+IMAGE_TAG=PR-14
Collaborator:

Using CI images from dmlc/xgboost-devops#14 to test the PR with the latest RAPIDS and Dask.

@TomAugspurger (Author)

Linting issues should be fixed now.

@TomAugspurger (Author)

I forgot to mention: there's probably a bunch of code that can be cleaned up if we go this route. I haven't attempted that yet; I'm just trying to get the tests passing.

iteration_range=iteration_range,
)

is_regression = isinstance(self, XGBRegressorBase)

meta: numpy.typing.NDArray[numpy.float32]
@TomAugspurger (Author)

meta here could be a cupy.ndarray too. Haven't looked into how to handle that yet.

@TomAugspurger (Author) commented Apr 18, 2025

The return type of Booster.predict is np.ndarray, so using np.ndarray here matches that. However, that's just the static type annotation. At runtime, IIUC, we could get a cupy array here.

I think we'll need to inspect the type of X._meta and figure out the return type based on that.
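
A possible shape for that inspection (illustrative only; the helper name and the exact cuDF/CuPy checks are assumptions):

import numpy as np

def _output_meta(X):
    # Dask collections carry an empty `_meta` object; its type tells us
    # whether partitions are pandas/numpy or cuDF/CuPy backed.
    try:
        import cudf
        import cupy
        if isinstance(X._meta, (cudf.DataFrame, cudf.Series, cupy.ndarray)):
            return cupy.empty(0, dtype=np.float32)
    except ImportError:
        pass
    return np.empty(0, dtype=np.float32)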

@hcho3 (Collaborator) commented Apr 17, 2025

Getting a Dask error:

ValueError: Cannot fuse tasks with multiple outputs {('any-7d7fe3a28f3a95da73683893b772f06f', 0, 0), ('_to_string_dtype-fused-values-b0183e1b96012ab6e674d647fcc98df3', 0)}

https://github.com/dmlc/xgboost/actions/runs/14502244978/job/40686317395?pr=11411

@TomAugspurger (Author)

I'll get a GPU env set up to reproduce that locally. @hcho3, do you know if the package versions (specifically cudf and dask-cudf) are printed anywhere for that job?

@rjzamora, I know we saw that "Cannot fuse tasks with multiple outputs" error in rapidsai/dask-upstream-testing#37. Do you remember whether rapidsai/cudf#18382 fixed that?

@hcho3 (Collaborator) commented Apr 17, 2025

@TomAugspurger The environment can be found at https://github.com/dmlc/xgboost-devops/actions/runs/14502381551/job/40684836721?pr=14.

From the build log of the CI container:

dask-core                             2025.2.0  pyhd8ed1ab_0                  conda-forge      968kB
dask                                  2025.2.0  pyhd8ed1ab_0                  conda-forge        8kB
dask-cudf                               25.4.0  cuda12_py310_250409_6bc42063  rapidsai          87kB
dask-cuda                             25.04.00  py310_250409_ge9ebd92_0       rapidsai         212kB
libcudf                                 25.4.0  cuda12_250409_6bc42063        rapidsai         357MB
pylibcudf                               25.4.0  cuda12_py310_250409_6bc42063  rapidsai           4MB
cudf                                    25.4.0  cuda12_py310_250409_6bc42063  rapidsai           1MB

@rjzamora (Contributor)

> @rjzamora, I know we saw that "Cannot fuse tasks with multiple outputs" error in rapidsai/dask-upstream-testing#37. Do you remember whether rapidsai/cudf#18382 fixed that?

Yes, hopefully cudf-25.06 includes the necessary fix for that error.

@TomAugspurger (Author) commented Apr 17, 2025

I'm having some trouble getting a GPU environment locally :/

Is there an easy way to test with cudf-25.06 in CI? I suspect that everything will work with that.


I ran cmake -B build -S . -DUSE_CUDA=ON -GNinja and then cd build && ninja, which eventually failed with

FAILED: src/CMakeFiles/objxgboost.dir/data/ellpack_page.cu.o 
/raid/toaugspurger/envs/xgboost-dev/bin/nvcc -forward-unknown-to-host-compiler -ccbin=/usr/bin/c++ -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_CUDA -DTHRUST_HOST_SYSTEM=THRUST_HOST_SYSTEM_CPP -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/nfs/toaugspurger/gh/dmlc/xgboost/include -I/home/nfs/toaugspurger/gh/dmlc/xgboost/dmlc-core/include -I/home/nfs/toaugspurger/gh/dmlc/xgboost/gputreeshap -I/home/nfs/toaugspurger/gh/dmlc/xgboost/build/dmlc-core/include -isystem /raid/toaugspurger/envs/xgboost-dev/targets/x86_64-linux/include -O3 -DNDEBUG -std=c++17 "--generate-code=arch=compute_50,code=[sm_50]" "--generate-code=arch=compute_60,code=[sm_60]" "--generate-code=arch=compute_70,code=[sm_70]" "--generate-code=arch=compute_80,code=[sm_80]" "--generate-code=arch=compute_90,code=[sm_90]" "--generate-code=arch=compute_100,code=[sm_100]" "--generate-code=arch=compute_120,code=[sm_120]" "--generate-code=arch=compute_120,code=[compute_120]" -Xcompiler=-fPIC --expt-extended-lambda --expt-relaxed-constexpr -Xcompiler=-fopenmp -Xfatbin=-compress-all --default-stream per-thread -lineinfo -MD -MT src/CMakeFiles/objxgboost.dir/data/ellpack_page.cu.o -MF src/CMakeFiles/objxgboost.dir/data/ellpack_page.cu.o.d -x cu -c /home/nfs/toaugspurger/gh/dmlc/xgboost/src/data/ellpack_page.cu -o src/CMakeFiles/objxgboost.dir/data/ellpack_page.cu.o
nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
/home/nfs/toaugspurger/gh/dmlc/xgboost/src/data/ellpack_page.cu:34:1: error: function ‘xgboost::EllpackPage::~EllpackPage()’ defaulted on its redeclaration with an exception-specification that differs from the implicit exception-specification ‘noexcept’
   34 | EllpackPage::~EllpackPage() noexcept(false) = default;
      | ^~~~~~~~~~~
ninja: build stopped: subcommand failed.

I fixed (maybe? I'm a C++ novice) that locally with this diff

diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index c3926cff1..6299f9f86 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -31,7 +31,7 @@ EllpackPage::EllpackPage() : impl_{new EllpackPageImpl{}} {}
 EllpackPage::EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param)
     : impl_{new EllpackPageImpl{ctx, dmat, param}} {}
 
-EllpackPage::~EllpackPage() noexcept(false) = default;
+EllpackPage::~EllpackPage() noexcept(false) {}
 
 EllpackPage::EllpackPage(EllpackPage&& that) { std::swap(impl_, that.impl_); }

Then

cd python-package
pip install .

worked, but it unexpectedly(?) downloaded nvidia-nccl-cu12 from PyPI. I expected that to already be in my conda env:

conda list nccl
# packages in environment at /raid/toaugspurger/envs/xgboost-dev:
#
# Name                    Version                   Build  Channel
nccl                      2.26.2.1             ha44e49d_1    conda-forge
nvidia-nccl-cu12          2.26.2.post1             pypi_0    pypi

And the test fails with

>   raise XGBoostError(py_str(_LIB.XGBGetLastError()))
E   xgboost.core.XGBoostError: [08:40:42] /home/nfs/toaugspurger/gh/dmlc/xgboost/src/collective/coll.cc:141: NCCL is required for device communication.
E   Stack trace:
E     [bt] (0) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(+0x4868f8) [0x7fc1888818f8]
E     [bt] (1) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::collective::Coll::MakeCUDAVar()+0x38) [0x7fc188881988]
E     [bt] (2) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::collective::CommGroup::Backend(xgboost::DeviceOrd) const+0x8a) [0x7fc1888a7aca]
E     [bt] (3) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::collective::Result xgboost::collective::GlobalSum<float, 1>(xgboost::Context const*, bool, xgboost::linalg::TensorView<float, 1>)+0x137) [0x7fc1889250c7]
E     [bt] (4) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::common::cuda_impl::SampleMean(xgboost::Context const*, bool, xgboost::linalg::TensorView<float const, 2>, xgboost::linalg::TensorView<float, 1>)+0x7e8) [0x7fc189100e28]
E     [bt] (5) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::common::SampleMean(xgboost::Context const*, bool, xgboost::linalg::Tensor<float, 2> const&, xgboost::linalg::Tensor<float, 1>*)+0xf07) [0x7fc1889afb77]
E     [bt] (6) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::obj::FitInterceptGlmLike::InitEstimation(xgboost::MetaInfo const&, xgboost::linalg::Tensor<float, 1>*) const+0xa6) [0x7fc188c95176]
E     [bt] (7) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(void xgboost::collective::ApplyWithLabels<float, xgboost::LearnerConfiguration::InitEstimation(xgboost::MetaInfo const&, xgboost::linalg::Tensor<float, 1>*)::{lambda()#1}>(xgboost::Context const*, xgboost::MetaInfo const&, xgboost::HostDeviceVector<float>*, xgboost::LearnerConfiguration::InitEstimation(xgboost::MetaInfo const&, xgboost::linalg::Tensor<float, 1>*)::{lambda()#1}&&)+0x164) [0x7fc188beccd4]
E     [bt] (8) /raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/lib/libxgboost.so(xgboost::LearnerConfiguration::InitBaseScore(xgboost::DMatrix const*)+0x192) [0x7fc188bf11a2]

/raid/toaugspurger/envs/xgboost-dev/lib/python3.12/site-packages/xgboost/core.py:321: XGBoostError

@trivialfis (Member) commented Apr 17, 2025

I remember an error like the one you posted and concluded it was a compiler bug, but I can't remember which compiler it was. For now, please keep the workaround local.

@trivialfis (Member) commented Apr 17, 2025

In almost all cases in Dask, self._predict_sync should already do something similar to map_partitions/map_blocks, since it goes through the inplace_predict function:

predts = await inplace_predict(

map_partitions is called here:

predictions = dd.map_partitions(

The Dask interface should be using inplace_predict: gblinear is the only booster that doesn't support in-place predict, and gblinear doesn't support distributed training anyway:

def _can_use_inplace_predict(self) -> bool:
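
Read together, the dispatch might look roughly like this sketch (a paraphrase of the discussion, not the actual methods; the client/booster plumbing is simplified and booster_type stands in for the real parameter lookup):

import xgboost
from xgboost import dask as dxgb

def can_use_inplace_predict(booster_type: str) -> bool:
    # Per the comment above: gblinear is the only booster without in-place
    # predict, and gblinear doesn't support distributed training anyway.
    return booster_type != "gblinear"

def dask_predict(client, booster: xgboost.Booster, X, booster_type: str = "gbtree"):
    if can_use_inplace_predict(booster_type):
        # Fast path: predicts directly on each partition, no DaskDMatrix.
        return dxgb.inplace_predict(client, booster, X)
    # Fallback (e.g. gblinear): DMatrix-based prediction path.
    return dxgb.predict(client, booster, X)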

@TomAugspurger (Author)

@trivialfis just confirming: you would expect DaskXGBClassifier.predict / DaskXGBRegressor.predict to call into the top-level xgboost.dask.predict (or .inplace_predict)? I'll start on that.

And I haven't been able to get a CUDA build working; I'm still hitting that issue with NCCL. If anyone is able to, I think that failing test will pass with any recent cudf / libcudf / pylibcudf from https://pypi.anaconda.org/rapidsai-wheels-nightly/simple/ and this branch.

@TomAugspurger (Author)

Mmm this is complicated by xgboost.dask.predict accepting a DaskDMatrix object. I'm not familiar with how that works, but it doesn't seem straightforward to support in the same way as DataFrame or Array inputs.

@TomAugspurger (Author)

I probably don't have enough context to suggest a good path forward here, but I think that ideally the predict method would try to do less. Maybe just:

  • determine the output meta (shape, dtype, etc.)
  • call X.map_{blocks,partitions}(booster.predict, ...) (see the sketch below)

But I'm probably missing some important reasons for why things are set up the way they currently are.
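
For the Array case, those two steps might look roughly like the sketch below (assuming 1-D float32 predictions; a helper like the meta inspection sketched earlier would replace the hard-coded numpy meta for GPU inputs). inplace_predict is used so each task avoids constructing a DMatrix:

import numpy as np
import dask.array as da
import xgboost

def minimal_predict(X: da.Array, booster: xgboost.Booster) -> da.Array:
    # Step 1: determine the output meta (shape, dtype, array backend).
    meta = np.empty(0, dtype=np.float32)
    # Step 2: let Dask map prediction over the blocks.
    return X.map_blocks(
        lambda block: booster.inplace_predict(block),
        drop_axis=1,  # 2-D feature blocks -> 1-D predictions
        dtype=np.float32,
        meta=meta,
    )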

@trivialfis (Member)

Apologies for the slow response, will look into it tomorrow.

Successfully merging this pull request may close these issues:
Test fails with Dask 2024.11.0+