# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
"""
Dask extensions for distributed training
----------------------------------------

See :doc:`Distributed XGBoost with Dask </tutorials/dask>` for a simple tutorial.
Also see :doc:`/python/dask-examples/index` for some examples.

There are two sets of APIs in this module.  One is the functional API including the
``train`` and ``predict`` functions.  The other is the stateful Scikit-Learn wrapper
inherited from the single-node Scikit-Learn interface.

The implementation is heavily influenced by dask_xgboost:
https://github.com/dask/dask-xgboost

Optional dask configuration
===========================

- **coll_cfg**:
    Specify the scheduler address along with communicator configurations. This can be
    used as a replacement of the existing global Dask configuration
    `xgboost.scheduler_address` (see below). See :ref:`tracker-ip` for more info. The
    `tracker_host_ip` should specify the IP address of the Dask scheduler node.

  .. versionadded:: 3.0.0

  .. code-block:: python

    from xgboost import dask as dxgb
    from xgboost.collective import Config

    coll_cfg = Config(
        retry=1, timeout=20, tracker_host_ip="10.23.170.98", tracker_port=0
    )

    clf = dxgb.DaskXGBClassifier(coll_cfg=coll_cfg)
    # or
    dxgb.train(client, {}, Xy, num_boost_round=10, coll_cfg=coll_cfg)

- **xgboost.scheduler_address**: Specify the scheduler address

  .. versionadded:: 1.6.0

  .. deprecated:: 3.0.0

  .. code-block:: python

      dask.config.set({"xgboost.scheduler_address": "192.0.0.100"})
      # We can also specify the port.
      dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"})

"""
import logging
from collections import defaultdict
from contextlib import contextmanager
from functools import partial, update_wrapper
from threading import Thread
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    ParamSpec,
    Sequence,
    Set,
    Tuple,
    TypeAlias,
    TypedDict,
    TypeGuard,
    TypeVar,
    Union,
)

import dask
import distributed
import numpy
from dask import array as da
from dask import bag as db
from dask import dataframe as dd
from dask.delayed import Delayed
from distributed import Future

from .. import collective, config
from .._data_utils import Categories
from .._typing import FeatureNames, FeatureTypes, IterationRange
from ..callback import TrainingCallback
from ..collective import Config as CollConfig
from ..collective import _Args as CollArgs
from ..collective import _ArgVals as CollArgsVals
from ..compat import _is_cudf_df
from ..core import (
    Booster,
    DMatrix,
    Metric,
    Objective,
    XGBoostError,
    _check_distributed_params,
    _deprecate_positional_args,
    _expect,
)
from ..data import _is_cudf_ser, _is_cupy_alike
from ..sklearn import (
    XGBClassifier,
    XGBClassifierBase,
    XGBModel,
    XGBRanker,
    XGBRankerMixIn,
    XGBRegressorBase,
    _can_use_qdm,
    _check_rf_callback,
    _cls_predict_proba,
    _objective_decorator,
    _wrap_evaluation_matrices,
    xgboost_model_doc,
)
from ..tracker import RabitTracker
from ..training import train as worker_train
from .data import _get_dmatrices, no_group_split
from .utils import _DASK_2024_12_1, _DASK_2025_3_0, get_address_from_user, get_n_threads

# Dask collections accepted as inputs, either as predictors or as meta info like labels.
_DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series]
_DataT: TypeAlias = Union[da.Array, dd.DataFrame]  # do not use series as predictor
# Result of training: the trained booster along with the evaluation history.
TrainReturnT = TypedDict(
    "TrainReturnT",
    {
        "booster": Booster,
        "history": Dict,
    },
)

# Names exported by this module.
__all__ = [
    "CommunicatorContext",
    "DaskDMatrix",
    "DaskQuantileDMatrix",
    "DaskXGBRegressor",
    "DaskXGBClassifier",
    "DaskXGBRanker",
    "DaskXGBRFRegressor",
    "DaskXGBRFClassifier",
    "train",
    "predict",
    "inplace_predict",
]

# TODOs:
#   - CV
#
# Note for developers:
#
#   As of writing asyncio is still a new feature of Python and in depth documentation is
#   rare.  Best examples of various asyncio tricks are in dask (luckily).  Classes like
#   Client, Worker are awaitable.  Some general rules for the implementation here:
#
#     - Synchronous world is different from asynchronous one, and they don't mix well.
#     - Write everything with async, then use distributed Client sync function to do the
#       switch.
#     - Use Any for the type hint when the return value can be a union of Awaitable and
#       a plain value.  This is because Client.sync can return either type depending
#       on context.  Right now there's no good way to silence the type checker for:
#
#         await train(...)
#
#       if train returns a Union type.


# Module-level logger, used to report fallbacks when binding the tracker address fails.
LOGGER = logging.getLogger("[xgboost.dask]")


def _try_start_tracker(
    n_workers: int,
    addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
    timeout: Optional[int],
) -> CollArgs:
    """Start a Rabit tracker on the first usable address in ``addrs``.

    Parameters
    ----------
    n_workers :
        Number of workers, forwarded to the tracker.
    addrs :
        Candidate addresses in order of preference.  A candidate is either a ``(host,
        port)`` tuple, a host string, or None.
    timeout :
        Timeout for the tracker, None is forwarded as 0.

    Returns
    -------
    The worker arguments reported by the tracker that has been started.

    Raises
    ------
    XGBoostError
        If starting the tracker fails and there is no other candidate left to try.

    """
    candidate = addrs[0]
    tracker_timeout = timeout if timeout is not None else 0

    try:
        if isinstance(candidate, tuple):
            # Both the host and the port are specified.
            rabit_tracker = RabitTracker(
                n_workers=n_workers,
                host_ip=candidate[0],
                port=candidate[1],
                sortby="task",
                timeout=tracker_timeout,
            )
        else:
            # Only the host (or nothing at all) is known, no port is passed along.
            assert isinstance(candidate, str) or candidate is None
            rabit_tracker = RabitTracker(
                n_workers=n_workers,
                host_ip=candidate,
                sortby="task",
                timeout=tracker_timeout,
            )

        rabit_tracker.start()
        # No timeout since we don't want to abort the training
        waiter = Thread(target=rabit_tracker.wait_for, daemon=True)
        waiter.start()

        env: CollArgs = {}
        env.update(rabit_tracker.worker_args())
        return env
    except XGBoostError as err:
        # Nothing else to fall back to, let the caller see the error.
        if len(addrs) < 2:
            raise
        LOGGER.warning(
            "Failed to bind address '%s', trying to use '%s' instead. Error:\n %s",
            str(addrs[0]),
            str(addrs[1]),
            str(err),
        )
        return _try_start_tracker(n_workers, addrs[1:], timeout)


def _start_tracker(
    n_workers: int,
    addr_from_dask: Optional[str],
    addr_from_user: Optional[Tuple[str, int]],
    timeout: Optional[int],
) -> CollArgs:
    """Start the Rabit tracker and return the arguments for workers.

    The address provided by the user is tried first, the one obtained from dask is used
    as the fallback.

    """
    candidates: List[Union[Optional[str], Optional[Tuple[str, int]]]] = [
        addr_from_user,
        addr_from_dask,
    ]
    return _try_start_tracker(n_workers, candidates, timeout)


class CommunicatorContext(collective.CommunicatorContext):
    """A context controlling collective communicator initialization and finalization.

    The context must be created inside a dask worker, the name and the address of the
    worker are used to build the task ID.

    """

    def __init__(self, **args: CollArgsVals) -> None:
        super().__init__(**args)

        current = distributed.get_worker()
        # The task ID is used for rank assignment, which keeps the RABIT rank consistent
        # with the dask worker name (consistent, but not the same: the task ID is a
        # string and "10" is sorted before "2").  The rank assignment is outsourced to
        # dask this way, which prevents non-deterministic issues.
        task_id = f"[xgboost.dask-{current.name}]:{current.address}"
        self.args["DMLC_TASK_ID"] = task_id


def _get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
    """Return ``client``, or the default dask client if ``client`` is None.

    The default client is looked up only when no client is passed in.  This way an
    explicitly provided client keeps working even if no default client is registered.

    Parameters
    ----------
    client :
        The client specified by the user, or None to use the default one from dask.

    Raises
    ------
    TypeError
        If ``client`` is neither None nor a :py:class:`distributed.Client`.

    """
    if client is None:
        return distributed.get_client()
    if not isinstance(client, distributed.Client):
        raise TypeError(_expect([distributed.Client, type(None)], type(client)))
    return client


# From the implementation point of view, DaskDMatrix complicates a lot of things.  A
# large portion of the code base is about syncing and extracting data from
# DaskDMatrix.  But having an independent data structure gives us a chance to perform
# some specialized optimizations, like building the histogram index directly.


class DaskDMatrix:
    # pylint: disable=too-many-instance-attributes
    """DMatrix holding on references to Dask DataFrame or Dask Array.  Constructing a
    `DaskDMatrix` forces all lazy computation to be carried out.  Wait for the input
    data explicitly if you want to see actual computation of constructing `DaskDMatrix`.

    See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters.  DaskDMatrix
    accepts only dask collection.

    .. note::

        `DaskDMatrix` does not repartition or move data between workers.  It's the
        caller's responsibility to balance the data.

    .. note::

        For aligning partitions with ranking query groups, use the
        :py:class:`DaskXGBRanker` and its ``allow_group_split`` option.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    client :
        Specify the dask client used for training.  Use default client returned from
        dask if it's set to None.

    """

    @_deprecate_positional_args
    def __init__(
        self,
        client: Optional["distributed.Client"],
        data: _DataT,
        label: Optional[_DaskCollection] = None,
        *,
        weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        missing: Optional[float] = None,
        silent: bool = False,  # pylint: disable=unused-argument
        feature_names: Optional[FeatureNames] = None,
        feature_types: Optional[FeatureTypes] = None,
        group: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        label_lower_bound: Optional[_DaskCollection] = None,
        label_upper_bound: Optional[_DaskCollection] = None,
        feature_weights: Optional[_DaskCollection] = None,
        enable_categorical: bool = False,
    ) -> None:
        client = _get_client(client)

        self.feature_names = feature_names
        self.feature_types = feature_types
        if isinstance(feature_types, Categories):
            raise TypeError(
                "The Dask interface can handle categories from DataFrame automatically."
            )
        # NaN is the missing value unless specified otherwise.
        self.missing = missing if missing is not None else numpy.nan
        self.enable_categorical = enable_categorical

        if qid is not None and weight is not None:
            raise NotImplementedError("per-group weight is not implemented.")
        if group is not None:
            raise NotImplementedError(
                "group structure is not implemented, use qid instead."
            )

        # Check the type before touching `data.shape`.  An input that is not a dask
        # collection should be reported with a `TypeError`, instead of whatever error
        # the attribute access happens to raise.
        if not isinstance(data, (dd.DataFrame, da.Array)):
            raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))

        if len(data.shape) != 2:
            raise ValueError(f"Expecting 2 dimensional input, got: {data.shape}")

        if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
            raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

        self._n_cols = data.shape[1]
        assert isinstance(self._n_cols, int)
        # Maps the address of a worker to the partitions it holds, filled by
        # `_map_local_data`.
        self.worker_map: Dict[str, List[Future]] = defaultdict(list)
        self.is_quantile: bool = False

        # The result of `client.sync` is kept for `__await__`.
        self._init = client.sync(
            self._map_local_data,
            client=client,
            data=data,
            label=label,
            weights=weight,
            base_margin=base_margin,
            qid=qid,
            feature_weights=feature_weights,
            label_lower_bound=label_lower_bound,
            label_upper_bound=label_upper_bound,
        )

    def __await__(self) -> Generator[None, None, "DaskDMatrix"]:
        """Await the result of the initialization started in the constructor."""
        return self._init.__await__()

    async def _map_local_data(
        self,
        *,
        client: "distributed.Client",
        data: _DataT,
        label: Optional[_DaskCollection] = None,
        weights: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        feature_weights: Optional[_DaskCollection] = None,
        label_lower_bound: Optional[_DaskCollection] = None,
        label_upper_bound: Optional[_DaskCollection] = None,
    ) -> "DaskDMatrix":
        """Obtain references to local data.

        The inputs are persisted and split into partitions.  Partitions with the same
        index are packed together, computed, and then recorded in ``worker_map`` under
        the address of a worker holding them.  ``partition_order`` and
        ``feature_weights`` are set as well.

        Returns
        -------
        The DMatrix itself.

        """

        def inconsistent(
            left: List[Any], left_name: str, right: List[Any], right_name: str
        ) -> str:
            msg = (
                f"Partitions between {left_name} and {right_name} are not "
                f"consistent: {len(left)} != {len(right)}.  "
                f"Please try to repartition/rechunk your data."
            )
            return msg

        def to_futures(d: _DaskCollection) -> List[Future]:
            """Breaking data into partitions."""
            d = client.persist(d)
            if (
                hasattr(d.partitions, "shape")
                and len(d.partitions.shape) > 1
                and d.partitions.shape[1] > 1
            ):
                raise ValueError(
                    "Data should be"
                    " partitioned by row. To avoid this specify the number"
                    " of columns for your dask Array explicitly. e.g."
                    " chunks=(partition_size, -1)"
                )
            return client.futures_of(d)

        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Future]]:
            if meta is not None:
                meta_parts: List[Future] = to_futures(meta)
                return meta_parts
            return None

        X_parts = to_futures(data)
        y_parts = flatten_meta(label)
        w_parts = flatten_meta(weights)
        margin_parts = flatten_meta(base_margin)
        qid_parts = flatten_meta(qid)
        ll_parts = flatten_meta(label_lower_bound)
        lu_parts = flatten_meta(label_upper_bound)

        parts: Dict[str, List[Future]] = {"data": X_parts}

        def append_meta(m_parts: Optional[List[Future]], name: str) -> None:
            # Meta info must have the same number of partitions as the data.
            if m_parts is not None:
                assert len(X_parts) == len(m_parts), inconsistent(
                    X_parts, "X", m_parts, name
                )
                parts[name] = m_parts

        append_meta(y_parts, "label")
        append_meta(w_parts, "weight")
        append_meta(margin_parts, "base_margin")
        append_meta(qid_parts, "qid")
        append_meta(ll_parts, "label_lower_bound")
        append_meta(lu_parts, "label_upper_bound")
        # At this point, `parts` looks like:
        # {"data": [x0, x1, ..], "label": [y0, y1, ..], ..} in future form

        # turn into list of dictionaries.
        packed_parts: List[Dict[str, Future]] = [
            {name: futures[i] for name, futures in parts.items()}
            for i in range(len(X_parts))
        ]

        # delay the zipped result
        # pylint: disable=no-member
        delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
        # At this point, the mental model should look like:
        # [{"data": x0, "label": y0, ..}, {"data": x1, "label": y1, ..}, ..]

        # Convert delayed objects into futures and make sure they are realized
        #
        # This also makes partitions to align (co-locate) on workers (X_0, y_0 should be
        # on the same worker).
        fut_parts: List[Future] = client.compute(delayed_parts)
        await distributed.wait(fut_parts)  # async wait for parts to be computed

        for part in fut_parts:
            # Each part is a future of {"data": x_i, "label": y_i, ..}.
            assert part.status == "finished", part.status

        # Preserving the partition order for prediction.
        self.partition_order = {part.key: i for i, part in enumerate(fut_parts)}

        key_to_partition = {part.key: part for part in fut_parts}
        who_has: Dict[str, Tuple[str, ...]] = await client.scheduler.who_has(
            keys=[part.key for part in fut_parts]
        )

        worker_map: Dict[str, List[Future]] = defaultdict(list)

        for key, workers in who_has.items():
            # Assign the partition to the first worker reported for it.
            worker_map[next(iter(workers))].append(key_to_partition[key])

        self.worker_map = worker_map

        if feature_weights is None:
            self.feature_weights = None
        else:
            self.feature_weights = await client.compute(feature_weights).result()

        return self

    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
        """Create a dictionary of objects that can be pickled for function
        arguments.

        Parameters
        ----------
        worker_addr :
            Address of the worker.  The ``parts`` entry is None if the worker is not
            in ``worker_map``.

        """
        return {
            "feature_names": self.feature_names,
            "feature_types": self.feature_types,
            "feature_weights": self.feature_weights,
            "missing": self.missing,
            "enable_categorical": self.enable_categorical,
            "parts": self.worker_map.get(worker_addr, None),
            "is_quantile": self.is_quantile,
        }

    def num_col(self) -> int:
        """Get the number of columns (features) in the DMatrix.

        Returns
        -------
        number of columns
        """
        return self._n_cols


# Return type and parameter specification of the function that is mapped onto workers
# by `map_worker_partitions`.
_MapRetT = TypeVar("_MapRetT")
_P = ParamSpec("_P")


async def map_worker_partitions(
    client: Optional["distributed.Client"],
    func: Callable[_P, _MapRetT],
    *refs: Any,
    workers: Sequence[str],
) -> _MapRetT:
    """Map a function onto partitions of each worker.

    One task is submitted for each address in ``workers``.  Every
    :py:class:`DaskDMatrix` in ``refs`` is replaced by the function arguments created
    for that worker, other references are passed to ``func`` as they are.

    Returns
    -------
    The first result from the workers that is not None, None if there isn't any.

    """
    # Note for function purity:
    # XGBoost is sensitive to data partition and uses random number generator.
    client = _get_client(client)

    def fn(_address: str, *args: _P.args, **kwargs: _P.kwargs) -> List[_MapRetT]:
        # Make sure the task runs on the worker it's submitted for before calling
        # `func`.
        current = distributed.get_worker()

        if current.address != _address:
            raise ValueError(
                f"Invalid worker address: {current.address}, expecting {_address}. "
                "This is likely caused by one of the workers died and Dask "
                "re-scheduled a different one. Resilience is not yet supported."
            )
        # Turn result into a list for bag construction
        return [func(*args, **kwargs)]

    def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
        return next((v for v in results if v is not None), None)

    futures = []
    for addr in workers:
        worker_args = [
            # pylint: disable=protected-access
            ref._create_fn_args(addr) if isinstance(ref, DaskDMatrix) else ref
            for ref in refs
        ]

        # XGBoost requires all workers running training tasks to be unique. Meaning, we
        # can't run 2 training jobs on the same node. This at best leads to an error
        # (NCCL unique check), at worst leads to extremely slow training performance
        # without any warning.
        #
        # See distributed.scheduler.decide_worker for `allow_other_workers`. In
        # summary, the scheduler chooses a worker from the valid set that has the task
        # dependencies. Each XGBoost's training task has all dependencies in a single
        # worker. As a result, the right worker should be picked by the scheduler even
        # if `allow_other_workers` is set to True.
        #
        # In addition, the scheduler only discards the valid set (the `workers` arg) if
        # no candidate can be found. This is likely caused by killed workers. In that
        # case, the check in `fn` should be able to stop the task. If we don't relax
        # the constraint and prevent Dask from choosing an invalid worker, the task
        # will simply hang. We prefer a quick error here.
        #
        futures.append(
            client.submit(
                update_wrapper(partial(fn, addr), fn),
                *worker_args,
                pure=False,
                workers=[addr],
                allow_other_workers=True,
            )
        )

    bag = db.from_delayed(futures)
    reduced = await bag.reduction(first_valid, first_valid)
    return await client.compute(reduced).result()


class DaskQuantileDMatrix(DaskDMatrix):
    """A dask version of :py:class:`QuantileDMatrix`. See :py:class:`DaskDMatrix` for
    parameter documents.

    """

    @_deprecate_positional_args
    def __init__(
        self,
        client: Optional["distributed.Client"],
        data: _DataT,
        label: Optional[_DaskCollection] = None,
        *,
        weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        missing: Optional[float] = None,
        silent: bool = False,  # pylint: disable=unused-argument
        feature_names: Optional[FeatureNames] = None,
        feature_types: Optional[Union[Any, List[Any]]] = None,
        max_bin: Optional[int] = None,
        ref: Optional[DaskDMatrix] = None,
        group: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        label_lower_bound: Optional[_DaskCollection] = None,
        label_upper_bound: Optional[_DaskCollection] = None,
        feature_weights: Optional[_DaskCollection] = None,
        enable_categorical: bool = False,
        max_quantile_batches: Optional[int] = None,
    ) -> None:
        # Everything that is not specific to the quantile DMatrix is handled by the
        # base class.
        super().__init__(
            client=client,
            data=data,
            label=label,
            weight=weight,
            base_margin=base_margin,
            missing=missing,
            silent=silent,
            feature_names=feature_names,
            feature_types=feature_types,
            group=group,
            qid=qid,
            label_lower_bound=label_lower_bound,
            label_upper_bound=label_upper_bound,
            feature_weights=feature_weights,
            enable_categorical=enable_categorical,
        )
        self.is_quantile = True
        self.max_bin = max_bin
        self.max_quantile_batches = max_quantile_batches
        # Only the ID of the reference is stored, not the reference itself.
        self._ref: Optional[int] = None if ref is None else id(ref)

    def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]:
        """Extend the function arguments from the base class with ``max_bin`` and
        ``max_quantile_batches``, along with the ID of the reference if there's one.

        """
        fn_args = super()._create_fn_args(worker_addr)
        fn_args["max_bin"] = self.max_bin
        fn_args["max_quantile_batches"] = self.max_quantile_batches
        if self._ref is None:
            return fn_args
        fn_args["ref"] = self._ref
        return fn_args


async def _get_rabit_args(
    client: "distributed.Client",
    n_workers: int,
    dconfig: Optional[Dict[str, Any]] = None,
    coll_cfg: Optional[CollConfig] = None,
) -> Dict[str, Union[str, int]]:
    """Get rabit context arguments from data distribution in DaskDMatrix.

    The tracker is started on the dask scheduler, the arguments it returns are then
    combined with the communicator configuration in ``coll_cfg``.

    """
    # There are 3 possible different addresses:
    # 1. Provided by user via dask.config
    # 2. Guessed by xgboost `get_host_ip` function
    # 3. From dask scheduler
    # We try 1 and 3 if 1 is available, otherwise 2 and 3.
    if coll_cfg is None:
        coll_cfg = CollConfig()

    # See if user config is available
    host_ip, port = get_address_from_user(dconfig, coll_cfg)
    user_addr: Optional[Tuple[str, int]] = (
        None if host_ip is None else (host_ip, port)
    )

    # Try address from dask scheduler, this might not work, see
    # https://github.com/dask/dask-xgboost/pull/40
    sched_addr: Optional[str]
    try:
        raw_addr = distributed.comm.get_address_host(client.scheduler.address)
        sched_addr = raw_addr.strip("/:")
    except Exception:  # pylint: disable=broad-except
        # Best effort, the tracker is started without the scheduler address.
        sched_addr = None

    # We assume the scheduler is a fair process and run the tracker there.
    tracker_env = await client.run_on_scheduler(
        _start_tracker, n_workers, sched_addr, user_addr, coll_cfg.tracker_timeout
    )
    env = coll_cfg.get_comm_config(tracker_env)
    assert env is not None
    return env


def _get_dask_config() -> Optional[Dict[str, Any]]:
    """Return the ``xgboost`` entry of the dask configuration, None if it's not set."""
    xgb_section = dask.config.get("xgboost", default=None)
    return xgb_section


# train and predict methods are supposed to be "functional", which meets the
# dask paradigm.  But as a side effect, the `evals_result` in single-node API
# is no longer supported since it mutates the input parameter, and it's not
# intuitive to sync the mutation result.  Therefore, a dictionary containing
# evaluation history is instead returned.


def _get_workers_from_data(
    dtrain: DaskDMatrix, evals: Optional[Sequence[Tuple[DaskDMatrix, str]]]
) -> List[str]:
    X_worker_map: Set[str] = set(dtrain.worker_map.keys())
    if evals:
        for e in evals:
            assert len(e) == 2
            assert isinstance(e[0], DaskDMatrix) and isinstance(e[1], str)
            if e[0] is dtrain:
                continue
            worker_map = set(e[0].worker_map.keys())
            X_worker_map = X_worker_map.union(worker_map)
    return list(X_worker_map)


async def _check_workers_are_alive(
    workers: List[str], client: "distributed.Client"
) -> None:
    info = await client.scheduler.identity()
    current_workers = info["workers"].keys()
    missing_workers = set(workers) - current_workers
    if missing_workers:
        raise RuntimeError(f"Missing required workers: {missing_workers}")


async def _train_async(
    *,
    client: "distributed.Client",
    global_config: Dict[str, Any],
    dconfig: Optional[Dict[str, Any]],
    params: Dict[str, Any],
    dtrain: DaskDMatrix,
    num_boost_round: int,
    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]],
    obj: Optional[Objective],
    early_stopping_rounds: Optional[int],
    verbose_eval: Union[int, bool],
    xgb_model: Optional[Booster],
    callbacks: Optional[Sequence[TrainingCallback]],
    custom_metric: Optional[Metric],
    coll_cfg: Optional[CollConfig],
) -> Optional[TrainReturnT]:
    """Run the training on all workers that hold the training or evaluation data.

    The workers are checked to be alive, the tracker is started, and then the training
    function is mapped onto the workers while holding a lock on them.

    Parameters
    ----------
    global_config :
        The XGBoost global configuration applied on the workers during training.  It's
        not modified by this function.
    dconfig :
        The optional dask configuration for XGBoost.

    Returns
    -------
    The result from the first worker that has data, None if there's no such worker.

    """
    workers = _get_workers_from_data(dtrain, evals)
    await _check_workers_are_alive(workers, client)
    coll_args = await _get_rabit_args(
        client, len(workers), dconfig=dconfig, coll_cfg=coll_cfg
    )
    _check_distributed_params(params)

    # This function name is displayed in the Dask dashboard task status, let's make it
    # clear that it's XGBoost training.
    def do_train(  # pylint: disable=too-many-positional-arguments
        parameters: Dict,
        coll_args: Dict[str, Union[str, int]],
        train_id: int,
        evals_name: List[str],
        evals_id: List[int],
        train_ref: dict,
        *refs: dict,
    ) -> Optional[TrainReturnT]:
        worker = distributed.get_worker()
        local_param = parameters.copy()
        n_threads = get_n_threads(local_param, worker)
        local_param.update({"nthread": n_threads, "n_jobs": n_threads})

        local_history: TrainingCallback.EvalsLog = {}
        # Use a copy for the worker specific configuration.  The captured dictionary
        # can be shared with the caller and with other tasks when the function is not
        # copied before running (workers living in the same process).
        local_config = {**global_config, "nthread": n_threads}

        with CommunicatorContext(**coll_args), config.config_context(**local_config):
            Xy, evals = _get_dmatrices(
                train_ref,
                train_id,
                *refs,
                evals_id=evals_id,
                evals_name=evals_name,
                n_threads=n_threads,
                # We need the model for reference categories.
                model=xgb_model,
            )

            booster = worker_train(
                params=local_param,
                dtrain=Xy,
                num_boost_round=num_boost_round,
                evals_result=local_history,
                evals=evals if len(evals) != 0 else None,
                obj=obj,
                custom_metric=custom_metric,
                early_stopping_rounds=early_stopping_rounds,
                verbose_eval=verbose_eval,
                xgb_model=xgb_model,
                callbacks=callbacks,
            )
        # Don't return the boosters from empty workers. It's quite difficult to
        # guarantee everything is in sync in the presence of empty workers, especially
        # with complex objectives like quantile.
        if Xy.num_row() != 0:
            ret: Optional[TrainReturnT] = {
                "booster": booster,
                "history": local_history,
            }
        else:
            ret = None
        return ret

    async with distributed.MultiLock(workers, client):
        if evals is not None:
            evals_data = [d for d, n in evals]
            evals_name = [n for d, n in evals]
            # The ID is used to identify each evaluation DMatrix on the workers.
            evals_id = [id(d) for d in evals_data]
        else:
            evals_data = []
            evals_name = []
            evals_id = []

        result = await map_worker_partitions(
            client,
            do_train,
            # extra function parameters
            params,
            coll_args,
            id(dtrain),
            evals_name,
            evals_id,
            *([dtrain] + evals_data),
            # workers to be used for training
            workers=workers,
        )
        return result


@_deprecate_positional_args
def train(
    client: "distributed.Client",
    params: Dict[str, Any],
    dtrain: DaskDMatrix,
    num_boost_round: int = 10,
    *,
    evals: Optional[Sequence[Tuple[DaskDMatrix, str]]] = None,
    obj: Optional[Objective] = None,
    early_stopping_rounds: Optional[int] = None,
    xgb_model: Optional[Booster] = None,
    verbose_eval: Union[int, bool] = True,
    callbacks: Optional[Sequence[TrainingCallback]] = None,
    custom_metric: Optional[Metric] = None,
    coll_cfg: Optional[CollConfig] = None,
) -> Any:
    """Train XGBoost model.

    .. versionadded:: 1.0.0

    .. note::

        Other parameters are the same as :py:func:`xgboost.train` except for
        `evals_result`, which is returned as part of function return value instead of
        argument.

    Parameters
    ----------
    client :
        Specify the dask client used for training.  Use default client returned from
        dask if it's set to None.

    coll_cfg :
        Configuration for the communicator used during training. See
        :py:class:`~xgboost.collective.Config`.

    Returns
    -------
    results: dict
        A dictionary containing trained booster and evaluation history.  `history` field
        is the same as `evals_result` from `xgboost.train`.

        .. code-block:: python

            {'booster': xgboost.Booster,
             'history': {'train': {'logloss': ['0.48253', '0.35953']},
                         'eval': {'logloss': ['0.480385', '0.357756']}}}

    """
    client = _get_client(client)
    # The global configuration and the dask configuration are obtained here on the
    # client side, and passed to the training along with the arguments of this
    # function.
    return client.sync(
        _train_async,
        global_config=config.get_config(),
        dconfig=_get_dask_config(),
        client=client,
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_round,
        evals=evals,
        obj=obj,
        early_stopping_rounds=early_stopping_rounds,
        xgb_model=xgb_model,
        verbose_eval=verbose_eval,
        callbacks=callbacks,
        custom_metric=custom_metric,
        coll_cfg=coll_cfg,
    )


def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
    return is_df and len(output_shape) <= 2


def _maybe_dataframe(
    data: Any, prediction: Any, columns: List[int], is_df: bool
) -> Any:
    """Return dataframe for prediction when applicable.

    The prediction is returned as it is if it cannot be represented by a dataframe.
    Otherwise, a cuDF dataframe is created for cuDF input and a pandas dataframe for
    everything else.

    """
    if not _can_output_df(is_df, prediction.shape):
        return prediction

    # Need to preserve the index for dataframe.
    # See issue: https://github.com/dmlc/xgboost/issues/6939
    # In older versions of dask, the partition is actually a numpy array when input
    # is dataframe.
    index = getattr(data, "index", None)
    is_empty = prediction.size == 0

    if _is_cudf_df(data):
        import cudf

        if is_empty:
            # No index is passed for the empty cuDF dataframe.
            return cudf.DataFrame({}, columns=columns, dtype=numpy.float32)
        return cudf.DataFrame(
            prediction, columns=columns, dtype=numpy.float32, index=index
        )

    import pandas as pd

    content = {} if is_empty else prediction
    return pd.DataFrame(content, columns=columns, dtype=numpy.float32, index=index)


async def _direct_predict_impl(  # pylint: disable=too-many-branches
    *,
    mapped_predict: Callable,
    booster: "distributed.Future",
    data: _DataT,
    base_margin: Optional[_DaskCollection],
    output_shape: Tuple[int, ...],
    meta: Dict[int, str],
) -> _DaskCollection:
    """Apply ``mapped_predict`` to each partition/block of a dask collection.

    Parameters
    ----------
    mapped_predict :
        Callable invoked as ``mapped_predict(booster, partition, is_df, columns,
        base_margin)`` for each partition (dataframe) or block (array).
    booster :
        Future of the model, forwarded to ``mapped_predict``.
    data :
        Input dask collection.
    base_margin :
        Optional margin; converted to the same kind of collection as the mapping
        primitive being used (dataframe for ``map_partitions``, array for
        ``map_blocks``).
    output_shape :
        Shape of a sample prediction, used to decide the output container and the
        axes/chunks passed to dask.
    meta :
        Column-index to dtype mapping describing the dataframe output.

    Returns
    -------
    A lazy dask series/dataframe when the output can be a dataframe, otherwise a
    lazy dask array.

    Raises
    ------
    ValueError
        When ``data`` is a dask dataframe and the output has 3 or more dimensions.
    """
    columns = tuple(meta.keys())
    if len(output_shape) >= 3 and isinstance(data, dd.DataFrame):
        # Without this check, dask will finish the prediction silently even if the
        # output has 3 or more dimensions.  But during map_partitions, dask passes a
        # `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by
        # `_convert_unknown_data` since dd.DataFrame is not known to xgboost native
        # binding.
        raise ValueError(
            "Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions."
        )
    if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
        # Dataframe path: predict partition by partition.
        if base_margin is not None and isinstance(base_margin, da.Array):
            # Easier for map_partitions
            base_margin_df: Optional[Union[dd.DataFrame, dd.Series]] = (
                base_margin.to_dask_dataframe()
            )
        else:
            base_margin_df = base_margin
        predictions = dd.map_partitions(
            mapped_predict,
            booster,
            data,
            True,
            columns,
            base_margin_df,
            meta=dd.utils.make_meta(meta),
        )
        # classification can return a dataframe, drop 1 dim when it's reg/binary
        if len(output_shape) == 1:
            predictions = predictions.iloc[:, 0]
    else:
        # Array path: predict block by block.
        if base_margin is not None and isinstance(
            base_margin, (dd.Series, dd.DataFrame)
        ):
            # Easier for map_blocks
            base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
        else:
            base_margin_array = base_margin
        # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
        # contrib)/3(contrib, interaction)/4(interaction) dims.
        if len(output_shape) == 1:
            drop_axis: Union[int, List[int]] = [1]  # drop from 2 to 1 dim.
            new_axis: Union[int, List[int]] = []
        else:
            drop_axis = []
            if isinstance(data, dd.DataFrame):
                new_axis = list(range(len(output_shape) - 2))
            else:
                # Extra output dimensions are appended after the two input axes.
                new_axis = [i + 2 for i in range(len(output_shape) - 2)]
        if len(output_shape) == 2:
            # Somehow dask fail to infer output shape change for 2-dim prediction, and
            #  `chunks = (None, output_shape[1])` doesn't work due to None is not
            #  supported in map_blocks.

            # data must be an array here as dataframe + 2-dim output predict will return
            # a dataframe instead.
            chunks: Optional[List[Tuple]] = list(data.chunks)
            assert isinstance(chunks, list)
            # Keep the row chunking, replace the column chunking with the output width.
            chunks[1] = (output_shape[1],)
        else:
            chunks = None
        predictions = da.map_blocks(
            mapped_predict,
            booster,
            data,
            False,
            columns,
            base_margin_array,
            chunks=chunks,
            drop_axis=drop_axis,
            new_axis=new_axis,
            dtype=numpy.float32,
        )
    return predictions


def _infer_predict_output(
    booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
    """Predict on a single random row to discover the prediction output shape.

    Returns the shape of the sample prediction along with the dataframe meta
    (column index -> dtype string), which is empty when the output can not be
    represented as a dataframe.
    """
    assert isinstance(features, int)
    sample = numpy.random.RandomState(1994).randn(1, features)

    if inplace:
        # Translate the inplace-predict argument into its `Booster.predict`
        # counterpart, without touching the caller's dictionary.
        kwargs = dict(kwargs)
        predict_type = kwargs.pop("predict_type")
        if predict_type == "margin":
            kwargs["output_margin"] = True

    dummy = DMatrix(sample, enable_categorical=True)
    # generated DMatrix doesn't have feature name, so no validation.
    predt_shape = booster.predict(dummy, validate_features=False, **kwargs).shape

    n_columns = 1 if len(predt_shape) <= 1 else predt_shape[1]
    meta: Dict[int, str] = (
        {col: "f4" for col in range(n_columns)}
        if _can_output_df(is_df, predt_shape)
        else {}
    )
    return predt_shape, meta


async def _get_model_future(
    client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future":
    """Return a future referring to the booster, scattering it when necessary.

    Accepts a booster, a training result dictionary (its ``"booster"`` entry is
    used), or an existing future whose underlying type is ``Booster``.  Raises
    ``TypeError`` for anything else.
    """
    # See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for the use
    # of hash.
    # https://github.com/dask/distributed/pull/8796 Don't use broadcast in the `scatter`
    # call, otherwise, the predict function might hang.
    if isinstance(model, Booster):
        return await client.scatter(model, hash=False)

    if isinstance(model, dict):
        return await client.scatter(model["booster"], hash=False)

    if isinstance(model, distributed.Future):
        underlying = model.type
        if underlying is not Booster:
            raise TypeError(
                f"Underlying type of model future should be `Booster`, got {underlying}"
            )
        return model

    raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))


# pylint: disable=too-many-statements
async def _predict_async(
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
    data: _DataT,
    *,
    output_margin: bool,
    missing: float,
    pred_leaf: bool,
    pred_contribs: bool,
    approx_contribs: bool,
    pred_interactions: bool,
    validate_features: bool,
    iteration_range: IterationRange,
    strict_shape: bool,
) -> _DaskCollection:
    """Asynchronous implementation of :py:func:`predict`.

    For a dask array or dataframe the prediction is mapped directly over the
    collection.  For a :py:class:`DaskDMatrix` one prediction task is submitted per
    partition, pinned to the worker recorded in ``data.worker_map``, and the
    results are concatenated into a dask array following the partition order.

    Raises
    ------
    TypeError
        When ``data`` is not a ``DaskDMatrix``, ``da.Array`` or ``dd.DataFrame``.
    """
    _booster = await _get_model_future(client, model)
    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))

    def mapped_predict(
        booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
    ) -> Any:
        # Per-partition prediction for dask collections.  The last positional
        # argument (base margin slot) is ignored here.
        with config.config_context(**global_config):
            m = DMatrix(
                data=partition,
                missing=missing,
                enable_categorical=True,
            )
            predt = booster.predict(
                data=m,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                validate_features=validate_features,
                iteration_range=iteration_range,
                strict_shape=strict_shape,
            )
            predt = _maybe_dataframe(partition, predt, columns, is_df)
            return predt

    # Predict on dask collection directly.
    if isinstance(data, (da.Array, dd.DataFrame)):
        # Run a dummy prediction remotely to learn the output shape and meta.
        _output_shape, meta = await client.compute(
            client.submit(
                _infer_predict_output,
                _booster,
                features=data.shape[1],
                is_df=isinstance(data, dd.DataFrame),
                inplace=False,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                strict_shape=strict_shape,
            )
        )
        return await _direct_predict_impl(
            mapped_predict=mapped_predict,
            booster=_booster,
            data=data,
            base_margin=None,
            output_shape=_output_shape,
            meta=meta,
        )

    # Only the shape is needed for DaskDMatrix; the output is always an array.
    output_shape, _ = await client.compute(
        client.submit(
            _infer_predict_output,
            booster=_booster,
            features=data.num_col(),
            is_df=False,
            inplace=False,
            output_margin=output_margin,
            pred_leaf=pred_leaf,
            pred_contribs=pred_contribs,
            approx_contribs=approx_contribs,
            pred_interactions=pred_interactions,
            strict_shape=strict_shape,
        )
    )
    # Prediction on dask DMatrix.
    partition_order = data.partition_order
    feature_names = data.feature_names
    feature_types = data.feature_types
    # The value stored in the DaskDMatrix replaces the `missing` argument.
    missing = data.missing

    def dispatched_predict(booster: Booster, part: Dict[str, Any]) -> numpy.ndarray:
        # Predict on a single DaskDMatrix partition; `part` carries the data and
        # optionally a base margin.
        data = part["data"]
        base_margin = part.get("base_margin", None)
        with config.config_context(**global_config):
            m = DMatrix(
                data,
                missing=missing,
                base_margin=base_margin,
                feature_names=feature_names,
                feature_types=feature_types,
                enable_categorical=True,
            )
            predt = booster.predict(
                m,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                validate_features=validate_features,
                iteration_range=iteration_range,
                strict_shape=strict_shape,
            )
            return predt

    # Flatten the worker -> partitions map into parallel lists.
    all_parts = []
    all_orders = []
    all_shapes = []
    all_workers: List[str] = []
    workers_address = list(data.worker_map.keys())
    for worker_addr in workers_address:
        list_of_parts = data.worker_map[worker_addr]
        all_parts.extend(list_of_parts)
        all_workers.extend(len(list_of_parts) * [worker_addr])
        all_orders.extend([partition_order[part.key] for part in list_of_parts])
    # Number of rows of each partition, computed on the worker owning it.
    for w, part in zip(all_workers, all_parts):
        s = client.submit(lambda part: part["data"].shape[0], part, workers=[w])
        all_shapes.append(s)

    # Restore the original partition order before assembling the output.
    parts_with_order = list(zip(all_parts, all_shapes, all_orders, all_workers))
    parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
    all_parts = [part for part, shape, order, w in parts_with_order]
    all_shapes = [shape for part, shape, order, w in parts_with_order]
    all_workers = [w for part, shape, order, w in parts_with_order]

    futures = []
    for w, part in zip(all_workers, all_parts):
        f = client.submit(dispatched_predict, _booster, part, workers=[w])
        futures.append(f)

    # Constructing a dask array from list of numpy arrays
    # See https://docs.dask.org/en/latest/array-creation.html
    arrays = []
    all_shapes = await client.gather(all_shapes)
    for i, rows in enumerate(all_shapes):
        # The leading dimension is the partition's row count; the remaining
        # dimensions come from the sample prediction.
        arrays.append(
            da.from_delayed(
                futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
            )
        )
    predictions = da.concatenate(arrays, axis=0)
    return predictions


@_deprecate_positional_args
def predict(  # pylint: disable=unused-argument
    client: Optional["distributed.Client"],
    model: Union[TrainReturnT, Booster, "distributed.Future"],
    data: Union[DaskDMatrix, _DataT],
    *,
    output_margin: bool = False,
    missing: float = numpy.nan,
    pred_leaf: bool = False,
    pred_contribs: bool = False,
    approx_contribs: bool = False,
    pred_interactions: bool = False,
    validate_features: bool = True,
    iteration_range: IterationRange = (0, 0),
    strict_shape: bool = False,
) -> Any:
    """Run prediction with a trained booster.

    .. note::

        Using ``inplace_predict`` might be faster when some features are not needed.
        See :py:meth:`xgboost.Booster.predict` for details on various parameters.  When
        output has more than 2 dimensions (shap value, leaf with strict_shape), input
        should be ``da.Array`` or ``DaskDMatrix``.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    client:
        Specify the dask client used for prediction.  Use default client
        returned from dask if it's set to None.
    model:
        The trained model.  It can be a booster, the dictionary returned by
        training (its ``booster`` entry is used), or a distributed.Future so user
        can pre-scatter it onto all workers.
    data:
        Input data used for prediction.  When input is a dataframe object,
        prediction output is a series or a dataframe depending on the output shape.
    missing:
        Used when input data is not DaskDMatrix.  Specify the value
        considered as missing.
    output_margin, pred_leaf, pred_contribs, approx_contribs, pred_interactions, \
validate_features, iteration_range, strict_shape:
        Forwarded to :py:meth:`xgboost.Booster.predict`, see there for details.

    Returns
    -------
    prediction: dask.array.Array/dask.dataframe.Series
        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
        an array, when input data is ``dask.dataframe.DataFrame``, return value can be
        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
        shape.

    """
    client = _get_client(client)
    # `locals()` forwards every argument of this function by name; do not introduce
    # additional local variables before this call.
    return client.sync(_predict_async, global_config=config.get_config(), **locals())


async def _inplace_predict_async(  # pylint: disable=too-many-branches
    *,
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
    data: _DataT,
    iteration_range: IterationRange,
    predict_type: str,
    missing: float,
    validate_features: bool,
    base_margin: Optional[_DaskCollection],
    strict_shape: bool,
) -> _DaskCollection:
    """Asynchronous implementation of :py:func:`inplace_predict`.

    The prediction is mapped over each partition/block of ``data`` using
    :py:meth:`xgboost.Booster.inplace_predict`.

    Raises
    ------
    TypeError
        When ``data`` is not a ``da.Array`` or ``dd.DataFrame``, or when
        ``base_margin`` is given but is not a ``da.Array``, ``dd.DataFrame`` or
        ``dd.Series``.
    """
    client = _get_client(client)
    booster = await _get_model_future(client, model)
    if not isinstance(data, (da.Array, dd.DataFrame)):
        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
    # Validate the margin itself.  `data` has already been checked above, so testing
    # `data` here again would never reject an invalid `base_margin`.
    if base_margin is not None and not isinstance(
        base_margin, (da.Array, dd.DataFrame, dd.Series)
    ):
        raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin)))

    def mapped_predict(
        booster: Booster,
        partition: Any,
        is_df: bool,
        columns: List[int],
        base_margin: Any,
    ) -> Any:
        # Per-partition prediction, executed under the caller's global config.
        with config.config_context(**global_config):
            prediction = booster.inplace_predict(
                partition,
                iteration_range=iteration_range,
                predict_type=predict_type,
                missing=missing,
                base_margin=base_margin,
                validate_features=validate_features,
                strict_shape=strict_shape,
            )
        prediction = _maybe_dataframe(partition, prediction, columns, is_df)
        return prediction

    # Run a dummy prediction remotely to learn the output shape and dataframe meta.
    # await turns future into value.
    shape, meta = await client.compute(
        client.submit(
            _infer_predict_output,
            booster,
            features=data.shape[1],
            is_df=isinstance(data, dd.DataFrame),
            inplace=True,
            predict_type=predict_type,
            iteration_range=iteration_range,
            strict_shape=strict_shape,
        )
    )
    return await _direct_predict_impl(
        mapped_predict=mapped_predict,
        booster=booster,
        data=data,
        base_margin=base_margin,
        output_shape=shape,
        meta=meta,
    )


@_deprecate_positional_args
def inplace_predict(  # pylint: disable=unused-argument
    client: Optional["distributed.Client"],
    model: Union[TrainReturnT, Booster, "distributed.Future"],
    data: _DataT,
    *,
    iteration_range: IterationRange = (0, 0),
    predict_type: str = "value",
    missing: float = numpy.nan,
    validate_features: bool = True,
    base_margin: Optional[_DaskCollection] = None,
    strict_shape: bool = False,
) -> Any:
    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for
    details.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    client:
        Specify the dask client used for prediction.  Use default client
        returned from dask if it's set to None.
    model:
        See :py:func:`xgboost.dask.predict` for details.
    data :
        dask collection, either a ``dask.array.Array`` or a
        ``dask.dataframe.DataFrame``.
    iteration_range:
        See :py:meth:`xgboost.Booster.predict` for details.
    predict_type:
        See :py:meth:`xgboost.Booster.inplace_predict` for details.
    missing:
        Value in the input data which needs to be present as a missing
        value. Defaults to np.nan.
    validate_features:
        See :py:meth:`xgboost.Booster.inplace_predict` for details.
    base_margin:
        See :py:obj:`xgboost.DMatrix` for details.

        .. versionadded:: 1.4.0

    strict_shape:
        See :py:meth:`xgboost.Booster.predict` for details.

        .. versionadded:: 1.4.0

    Returns
    -------
    prediction :
        When input data is ``dask.array.Array``, the return value is an array, when
        input data is ``dask.dataframe.DataFrame``, return value can be
        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
        shape.

    """
    client = _get_client(client)
    # When used in asynchronous environment, the `client` object should have
    # `asynchronous` attribute as True.  When invoked by the skl interface, it's
    # responsible for setting up the client.
    #
    # `locals()` forwards every argument of this function by name; do not introduce
    # additional local variables before this call.
    return client.sync(
        _inplace_predict_async, global_config=config.get_config(), **locals()
    )


async def _async_wrap_evaluation_matrices(
    client: Optional["distributed.Client"],
    device: Optional[str],
    tree_method: Optional[str],
    max_bin: Optional[int],
    **kwargs: Any,
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
    """Build the training and evaluation matrices and await their construction.

    The quantile variant of the dask matrix is used whenever the tree method and
    device allow it, otherwise the regular dask matrix is created.
    """

    def _make_matrix(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
        if not _can_use_qdm(tree_method, device):
            return DaskDMatrix(client=client, **kwargs)
        return DaskQuantileDMatrix(client=client, ref=ref, max_bin=max_bin, **kwargs)

    train_dmatrix, evals = _wrap_evaluation_matrices(
        create_dmatrix=_make_matrix, **kwargs
    )
    train_dmatrix = await train_dmatrix
    if evals is None:
        return train_dmatrix, evals

    resolved = []
    for entry in evals:
        # The training matrix may be reused for evaluation; it is already awaited.
        if entry[0] is not train_dmatrix:
            entry = (await entry[0], entry[1])
        resolved.append(entry)
    return train_dmatrix, resolved


@contextmanager
def _set_worker_client(
    model: "DaskScikitLearnBase", client: "distributed.Client"
) -> Generator:
    """Temporarily set the client for sklearn model."""
    try:
        model.client = client
        yield model
    finally:
        model.client = None  # type:ignore


class DaskScikitLearnBase(XGBModel):
    """Base class for implementing scikit-learn interface with Dask"""

    # Class-level default; the `client` property setter stores the per-instance value.
    _client = None

    def __init__(self, *, coll_cfg: Optional[CollConfig] = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        self.coll_cfg = coll_cfg

    async def _predict_async(
        self,
        data: _DataT,
        *,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[IterationRange],
    ) -> Any:
        """Asynchronous prediction based on :py:func:`inplace_predict`.

        A dataframe result is converted into a dask array before it is returned.
        """
        iteration_range = self._get_iteration_range(iteration_range)
        # Dask doesn't support gblinear and accepts only Dask collection types (array
        # and dataframe). We can perform inplace predict.
        assert self._can_use_inplace_predict()
        predts = await inplace_predict(
            client=self.client,
            model=self.get_booster(),
            data=data,
            iteration_range=iteration_range,
            predict_type="margin" if output_margin else "value",
            missing=self.missing,
            base_margin=base_margin,
            validate_features=validate_features,
        )
        if isinstance(predts, dd.DataFrame):
            predts = predts.to_dask_array()
            # Make sure the booster is part of the task graph implicitly
            # only needed for certain versions of dask.
            if _DASK_2024_12_1() and not _DASK_2025_3_0():
                # Only applied to the dask versions selected by the check above.
                # Dask==2025.3.0 fails with:
                #     RuntimeError: Attempting to use an asynchronous
                #     Client in a synchronous context of `dask.compute`
                #
                # Dask==2025.4.0 fails with:
                #     TypeError: Value type is not supported for data
                #     iterator:<class 'distributed.client.Future'>
                predts = predts.persist()
        return predts

    @_deprecate_positional_args
    def predict(
        self,
        X: _DataT,
        *,
        output_margin: bool = False,
        validate_features: bool = True,
        base_margin: Optional[_DaskCollection] = None,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        return self.client.sync(
            self._predict_async,
            X,
            output_margin=output_margin,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )

    async def _apply_async(
        self,
        X: _DataT,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        """Asynchronous implementation of `apply`: leaf prediction on a DaskDMatrix."""
        iteration_range = self._get_iteration_range(iteration_range)
        test_dmatrix: DaskDMatrix = await DaskDMatrix(
            self.client,
            data=X,
            missing=self.missing,
            feature_types=self.feature_types,
        )
        predts = await predict(
            self.client,
            model=self.get_booster(),
            data=test_dmatrix,
            pred_leaf=True,
            iteration_range=iteration_range,
        )
        return predts

    def apply(
        self,
        X: _DataT,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        return self.client.sync(self._apply_async, X, iteration_range=iteration_range)

    def __await__(self) -> Awaitable[Any]:
        # Generate a coroutine wrapper to make this class awaitable.
        async def _() -> Awaitable[Any]:
            return self

        return self._client_sync(_).__await__()

    def __getstate__(self) -> Dict:
        # The client is dropped from the state; see the `client` property for why.
        this = self.__dict__.copy()
        this.pop("_client", None)
        return this

    @property
    def client(self) -> "distributed.Client":
        """The dask client used in this model.  The `Client` object can not be
        serialized for transmission, so if task is launched from a worker instead of
        directly from the client process, this attribute needs to be set at that worker.

        """

        client = _get_client(self._client)
        return client

    @client.setter
    def client(self, clt: "distributed.Client") -> None:
        # calling `worker_client' doesn't return the correct `asynchronous` attribute,
        # so we have to pass it ourselves.
        self._asynchronous = clt.asynchronous if clt is not None else False
        self._client = clt

    def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
        """Get the correct client, when method is invoked inside a worker we
        should use `worker_client' instead of default client.

        """

        if self._client is None:
            asynchronous = getattr(self, "_asynchronous", False)
            try:
                distributed.get_worker()
                in_worker = True
            except ValueError:
                in_worker = False
            if in_worker:
                # The worker client is attached to the model only while `func` runs.
                with distributed.worker_client() as client:
                    with _set_worker_client(self, client) as this:
                        return this.client.sync(
                            func, **kwargs, asynchronous=asynchronous
                        )

        return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous)


@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
)
class DaskXGBRegressor(XGBRegressorBase, DaskScikitLearnBase):
    """dummy doc string to workaround pylint, replaced by the decorator."""

    async def _fit_async(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
        verbose: Union[int, bool],
        xgb_model: Optional[Union[Booster, XGBModel]],
        feature_weights: Optional[_DaskCollection],
    ) -> "DaskXGBRegressor":
        """Asynchronous implementation of `fit`.

        Builds the training/evaluation matrices, runs the distributed training and
        stores the resulting booster and evaluation history on the estimator, which
        is then returned.
        """
        params = self.get_xgb_params()
        model, metric, params, feature_weights = self._configure_fit(
            xgb_model, params, feature_weights
        )

        # Regression has no query information, hence group/qid are always None.
        dtrain, evals = await _async_wrap_evaluation_matrices(
            client=self.client,
            device=self.device,
            tree_method=self.tree_method,
            max_bin=self.max_bin,
            X=X,
            y=y,
            group=None,
            qid=None,
            sample_weight=sample_weight,
            base_margin=base_margin,
            feature_weights=feature_weights,
            eval_set=eval_set,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            eval_group=None,
            eval_qid=None,
            missing=self.missing,
            enable_categorical=self.enable_categorical,
            feature_types=self.feature_types,
        )

        # A callable objective is wrapped and passed as a custom objective.
        if callable(self.objective):
            obj: Optional[Callable] = _objective_decorator(self.objective)
        else:
            obj = None
        results = await self.client.sync(
            _train_async,
            asynchronous=True,
            client=self.client,
            global_config=config.get_config(),
            dconfig=_get_dask_config(),
            params=params,
            dtrain=dtrain,
            num_boost_round=self.get_num_boosting_rounds(),
            evals=evals,
            obj=obj,
            custom_metric=metric,
            verbose_eval=verbose,
            early_stopping_rounds=self.early_stopping_rounds,
            callbacks=self.callbacks,
            coll_cfg=self.coll_cfg,
            xgb_model=model,
        )
        self._Booster = results["booster"]
        self._set_evaluation_result(results["history"])
        return self

    # pylint: disable=missing-docstring, disable=unused-argument
    @_deprecate_positional_args
    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRegressor":
        # Forward every argument of this method by name.  `locals()` must be
        # evaluated before any other local variable is introduced.
        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
        return self._client_sync(self._fit_async, **args)


@xgboost_model_doc(
    "Implementation of the scikit-learn API for XGBoost classification.",
    ["estimators", "model"],
)
class DaskXGBClassifier(XGBClassifierBase, DaskScikitLearnBase):
    # pylint: disable=missing-class-docstring
    async def _fit_async(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
        verbose: Union[int, bool],
        xgb_model: Optional[Union[Booster, XGBModel]],
        feature_weights: Optional[_DaskCollection],
    ) -> "DaskXGBClassifier":
        """Asynchronous implementation of `fit`.

        Builds the training/evaluation matrices, computes the set of classes from
        ``y``, selects the objective from the number of classes, runs the distributed
        training and stores the results on the estimator, which is then returned.
        """
        params = self.get_xgb_params()
        model, metric, params, feature_weights = self._configure_fit(
            xgb_model, params, feature_weights
        )

        # Classification has no query information, hence group/qid are always None.
        dtrain, evals = await _async_wrap_evaluation_matrices(
            client=self.client,
            device=self.device,
            tree_method=self.tree_method,
            max_bin=self.max_bin,
            X=X,
            y=y,
            group=None,
            qid=None,
            sample_weight=sample_weight,
            base_margin=base_margin,
            feature_weights=feature_weights,
            eval_set=eval_set,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            eval_group=None,
            eval_qid=None,
            missing=self.missing,
            enable_categorical=self.enable_categorical,
            feature_types=self.feature_types,
        )

        # Compute the distinct labels and normalize them into a host numpy array.
        # pylint: disable=attribute-defined-outside-init
        if isinstance(y, da.Array):
            self.classes_ = await self.client.compute(da.unique(y))
        else:
            self.classes_ = await self.client.compute(y.drop_duplicates())
        if _is_cudf_ser(self.classes_):
            self.classes_ = self.classes_.to_cupy()
        if _is_cupy_alike(self.classes_):
            self.classes_ = self.classes_.get()
        self.classes_ = numpy.array(self.classes_)
        self.n_classes_ = len(self.classes_)

        if self.n_classes_ > 2:
            params["objective"] = "multi:softprob"
            params["num_class"] = self.n_classes_
        else:
            params["objective"] = "binary:logistic"

        # A callable objective is wrapped and passed as a custom objective.
        if callable(self.objective):
            obj: Optional[Callable] = _objective_decorator(self.objective)
        else:
            obj = None
        results = await self.client.sync(
            _train_async,
            asynchronous=True,
            client=self.client,
            global_config=config.get_config(),
            dconfig=_get_dask_config(),
            params=params,
            dtrain=dtrain,
            num_boost_round=self.get_num_boosting_rounds(),
            evals=evals,
            obj=obj,
            custom_metric=metric,
            verbose_eval=verbose,
            early_stopping_rounds=self.early_stopping_rounds,
            callbacks=self.callbacks,
            coll_cfg=self.coll_cfg,
            xgb_model=model,
        )
        self._Booster = results["booster"]
        # Record the objective that was actually used, unless it is user supplied.
        if not callable(self.objective):
            self.objective = params["objective"]
        self._set_evaluation_result(results["history"])
        return self

    # pylint: disable=unused-argument
    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBClassifier":
        # Forward every argument of this method by name.  `locals()` must be
        # evaluated before any other local variable is introduced.
        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
        return self._client_sync(self._fit_async, **args)

    async def _predict_proba_async(
        self,
        X: _DataT,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[IterationRange],
    ) -> _DaskCollection:
        """Asynchronous implementation of `predict_proba`.

        Raises ``ValueError`` when the objective is ``multi:softmax``.
        """
        if self.objective == "multi:softmax":
            # `multi:softprob` is the probability-producing multi-class objective,
            # the same one selected in `_fit_async`.
            raise ValueError(
                "multi:softmax doesn't support `predict_proba`.  "
                "Switch to `multi:softprob` instead"
            )
        predts = await super()._predict_async(
            data=X,
            output_margin=False,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
        # Dask variant of vstack that tolerates unknown chunk sizes.
        vstack = update_wrapper(
            partial(da.vstack, allow_unknown_chunksizes=True), da.vstack
        )
        return _cls_predict_proba(getattr(self, "n_classes_", 0), predts, vstack)

    # pylint: disable=missing-function-docstring
    def predict_proba(
        self,
        X: _DaskCollection,
        validate_features: bool = True,
        base_margin: Optional[_DaskCollection] = None,
        iteration_range: Optional[IterationRange] = None,
    ) -> Any:
        return self._client_sync(
            self._predict_proba_async,
            X=X,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )

    predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__

    async def _predict_async(
        self,
        data: _DataT,
        *,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[IterationRange],
    ) -> _DaskCollection:
        """Asynchronous class prediction.

        Returns the raw margin when ``output_margin`` is set.  Otherwise a
        1-dim output is thresholded at 0.5 and a 2-dim output is reduced with
        argmax over the class axis.
        """
        pred_probs = await super()._predict_async(
            data,
            output_margin=output_margin,
            validate_features=validate_features,
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
        if output_margin:
            return pred_probs

        if len(pred_probs.shape) == 1:
            preds = (pred_probs > 0.5).astype(int)
        else:
            assert len(pred_probs.shape) == 2
            assert isinstance(pred_probs, da.Array)
            # when using da.argmax directly, dask will construct a numpy based return
            # array, which runs into error when computing GPU based prediction.

            def _argmax(x: Any) -> Any:
                return x.argmax(axis=1)

            preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
        return preds


@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost Ranking.

    .. versionadded:: 1.4.0

""",
    ["estimators", "model"],
    extra_parameters="""
    allow_group_split :

        .. versionadded:: 3.0.0

        Whether a query group can be split among multiple workers. When set to `False`,
        inputs must be Dask dataframes or series. If you have many small query groups,
        this can significantly increase the fragmentation of the data, and the internal
        DMatrix construction can take longer.

""",
    end_note="""
        .. note::

            For the dask implementation, group is not supported, use qid instead.
""",
)
class DaskXGBRanker(XGBRankerMixIn, DaskScikitLearnBase):
    @_deprecate_positional_args
    def __init__(
        self,
        *,
        objective: str = "rank:ndcg",
        allow_group_split: bool = False,
        coll_cfg: Optional[CollConfig] = None,
        **kwargs: Any,
    ) -> None:
        # Only built-in (string) objectives are accepted by the ranker.
        if callable(objective):
            raise ValueError("Custom objective function not supported by XGBRanker.")
        self.allow_group_split = allow_group_split
        super().__init__(objective=objective, coll_cfg=coll_cfg, **kwargs)

    def _wrapper_params(self) -> Set[str]:
        """Return the parent's wrapper parameters plus ``allow_group_split``."""
        params = super()._wrapper_params()
        params.add("allow_group_split")
        return params

    async def _fit_async(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        qid: Optional[_DaskCollection],
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
        base_margin_eval_set: Optional[Sequence[_DaskCollection]],
        eval_qid: Optional[Sequence[_DaskCollection]],
        verbose: Union[int, bool],
        xgb_model: Optional[Union[XGBModel, Booster]],
        feature_weights: Optional[_DaskCollection],
    ) -> "DaskXGBRanker":
        """Build the training and evaluation matrices, then run the training.

        The matrices come from ``_async_wrap_evaluation_matrices`` (``group`` and
        ``eval_group`` are always passed as ``None``; only ``qid`` is forwarded).
        Training is dispatched to ``_train_async`` through the client.  The resulting
        booster is stored in ``_Booster`` and the evaluation history in
        ``evals_result_``.  Returns ``self``.
        """
        params = self.get_xgb_params()
        model, metric, params, feature_weights = self._configure_fit(
            xgb_model, params, feature_weights
        )
        dtrain, evals = await _async_wrap_evaluation_matrices(
            self.client,
            device=self.device,
            tree_method=self.tree_method,
            max_bin=self.max_bin,
            X=X,
            y=y,
            group=None,
            qid=qid,
            sample_weight=sample_weight,
            base_margin=base_margin,
            feature_weights=feature_weights,
            eval_set=eval_set,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            eval_group=None,
            eval_qid=eval_qid,
            missing=self.missing,
            enable_categorical=self.enable_categorical,
            feature_types=self.feature_types,
        )
        results = await self.client.sync(
            _train_async,
            asynchronous=True,
            client=self.client,
            global_config=config.get_config(),
            dconfig=_get_dask_config(),
            params=params,
            dtrain=dtrain,
            num_boost_round=self.get_num_boosting_rounds(),
            evals=evals,
            obj=None,
            custom_metric=metric,
            verbose_eval=verbose,
            early_stopping_rounds=self.early_stopping_rounds,
            callbacks=self.callbacks,
            xgb_model=model,
            coll_cfg=self.coll_cfg,
        )
        self._Booster = results["booster"]
        self.evals_result_ = results["history"]
        return self

    # pylint: disable=unused-argument, arguments-differ
    @_deprecate_positional_args
    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        group: Optional[_DaskCollection] = None,
        qid: Optional[_DaskCollection] = None,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        eval_group: Optional[Sequence[_DaskCollection]] = None,
        eval_qid: Optional[Sequence[_DaskCollection]] = None,
        verbose: Optional[Union[int, bool]] = False,
        xgb_model: Optional[Union[XGBModel, str, Booster]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRanker":
        # The docstring is assigned from `XGBRanker.fit` right after this definition.
        #
        # `group`/`eval_group` are rejected outright: the dask interface only
        # supports `qid`, which is mandatory.
        msg = "Use the `qid` instead of the `group` with the dask interface."
        if not (group is None and eval_group is None):
            raise ValueError(msg)
        if qid is None:
            raise ValueError("`qid` is required for ranking.")

        # Both checkers either raise `TypeError` or return True, which lets them be
        # used inside `assert` so the type checker narrows the argument type.
        def check_df(X: _DaskCollection) -> TypeGuard[dd.DataFrame]:
            if not isinstance(X, dd.DataFrame):
                raise TypeError(
                    "When `allow_group_split` is set to False, X is required to be"
                    " a dataframe."
                )
            return True

        def check_ser(
            qid: Optional[_DaskCollection], name: str
        ) -> TypeGuard[Optional[dd.Series]]:
            # `None` is accepted: optional metadata may be absent.
            if not isinstance(qid, dd.Series) and qid is not None:
                raise TypeError(
                    f"When `allow_group_split` is set to False, {name} is required to "
                    "be a series."
                )
            return True

        if not self.allow_group_split:
            assert (
                check_df(X)
                and check_ser(qid, "qid")
                and check_ser(y, "y")
                and check_ser(sample_weight, "sample_weight")
                and check_ser(base_margin, "base_margin")
            )
            assert qid is not None and y is not None
            # Remember the identity of the caller's training frame: `X` is rebound
            # below, and eval sets that reuse the training frame are detected by id.
            X_id = id(X)
            # NOTE(review): `no_group_split` is defined elsewhere; presumably it
            # re-arranges the inputs so that a query group stays on one worker.
            X, qid, y, sample_weight, base_margin = no_group_split(
                self.device,
                X,
                qid,
                y=y,
                sample_weight=sample_weight,
                base_margin=base_margin,
            )

            if eval_set is not None:
                new_eval_set = []
                new_eval_qid = []
                new_sample_weight_eval_set = []
                new_base_margin_eval_set = []
                assert eval_qid
                for i, (Xe, ye) in enumerate(eval_set):
                    we = sample_weight_eval_set[i] if sample_weight_eval_set else None
                    be = base_margin_eval_set[i] if base_margin_eval_set else None
                    assert check_df(Xe)
                    assert eval_qid
                    qe = eval_qid[i]
                    assert (
                        eval_qid
                        and check_ser(qe, "qid")
                        and check_ser(ye, "y")
                        and check_ser(we, "sample_weight")
                        and check_ser(be, "base_margin")
                    )
                    assert qe is not None and ye is not None
                    if id(Xe) != X_id:
                        Xe, qe, ye, we, be = no_group_split(
                            self.device, Xe, qe, ye, we, be
                        )
                    else:
                        # The eval set is the training frame itself: reuse the
                        # already processed training inputs and metadata.
                        Xe, qe, ye, we, be = X, qid, y, sample_weight, base_margin

                    new_eval_set.append((Xe, ye))
                    new_eval_qid.append(qe)

                    # Append unconditionally (including `None`) so that entry `i` of
                    # each metadata list keeps describing `eval_set[i]`.  Skipping
                    # `None` entries would shorten the lists and shift later
                    # weights/margins onto the wrong evaluation set.
                    new_sample_weight_eval_set.append(we)
                    new_base_margin_eval_set.append(be)

                eval_set = new_eval_set
                eval_qid = new_eval_qid
                # Collapse to `None` only when no evaluation set has the metadata.
                sample_weight_eval_set = (
                    new_sample_weight_eval_set
                    if any(w is not None for w in new_sample_weight_eval_set)
                    else None
                )
                base_margin_eval_set = (
                    new_base_margin_eval_set
                    if any(b is not None for b in new_base_margin_eval_set)
                    else None
                )

        return self._client_sync(
            self._fit_async,
            X=X,
            y=y,
            qid=qid,
            sample_weight=sample_weight,
            base_margin=base_margin,
            eval_set=eval_set,
            eval_qid=eval_qid,
            verbose=verbose,
            xgb_model=xgb_model,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            feature_weights=feature_weights,
        )

    # FIXME(trivialfis): arguments differ due to additional parameters like group and
    # qid.
    fit.__doc__ = XGBRanker.fit.__doc__


@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost Random Forest Regressor.

    .. versionadded:: 1.4.0

""",
    ["model", "objective"],
    extra_parameters="""
    n_estimators : int
        Number of trees in random forest to fit.
""",
)
class DaskXGBRFRegressor(DaskXGBRegressor):
    @_deprecate_positional_args
    def __init__(
        self,
        *,
        learning_rate: Optional[float] = 1,
        subsample: Optional[float] = 0.8,
        colsample_bynode: Optional[float] = 0.8,
        reg_lambda: Optional[float] = 1e-5,
        coll_cfg: Optional[CollConfig] = None,
        **kwargs: Any,
    ) -> None:
        # Random-forest specific defaults; everything else goes to the parent as-is.
        forest_defaults = {
            "learning_rate": learning_rate,
            "subsample": subsample,
            "colsample_bynode": colsample_bynode,
            "reg_lambda": reg_lambda,
        }
        super().__init__(**forest_defaults, coll_cfg=coll_cfg, **kwargs)

    def get_xgb_params(self) -> Dict[str, Any]:
        """Return the parent's parameters with ``num_parallel_tree`` set to
        ``n_estimators``."""
        booster_params = super().get_xgb_params()
        booster_params.update(num_parallel_tree=self.n_estimators)
        return booster_params

    def get_num_boosting_rounds(self) -> int:
        """Always 1: ``n_estimators`` is passed as ``num_parallel_tree`` instead."""
        return 1

    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRFRegressor":
        # Check the early stopping / callback configuration before training, then
        # forward every argument to the parent `fit` by keyword.
        _check_rf_callback(self.early_stopping_rounds, self.callbacks)
        super().fit(
            X=X,
            y=y,
            sample_weight=sample_weight,
            base_margin=base_margin,
            eval_set=eval_set,
            verbose=verbose,
            xgb_model=xgb_model,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            feature_weights=feature_weights,
        )
        return self


@xgboost_model_doc(
    """Implementation of the Scikit-Learn API for XGBoost Random Forest Classifier.

    .. versionadded:: 1.4.0

""",
    ["model", "objective"],
    extra_parameters="""
    n_estimators : int
        Number of trees in random forest to fit.
""",
)
class DaskXGBRFClassifier(DaskXGBClassifier):
    @_deprecate_positional_args
    def __init__(
        self,
        *,
        learning_rate: Optional[float] = 1,
        subsample: Optional[float] = 0.8,
        colsample_bynode: Optional[float] = 0.8,
        reg_lambda: Optional[float] = 1e-5,
        coll_cfg: Optional[CollConfig] = None,
        **kwargs: Any,
    ) -> None:
        # Random-forest specific defaults; everything else goes to the parent as-is.
        forest_defaults = {
            "learning_rate": learning_rate,
            "subsample": subsample,
            "colsample_bynode": colsample_bynode,
            "reg_lambda": reg_lambda,
        }
        super().__init__(**forest_defaults, coll_cfg=coll_cfg, **kwargs)

    def get_xgb_params(self) -> Dict[str, Any]:
        """Return the parent's parameters with ``num_parallel_tree`` set to
        ``n_estimators``."""
        booster_params = super().get_xgb_params()
        booster_params.update(num_parallel_tree=self.n_estimators)
        return booster_params

    def get_num_boosting_rounds(self) -> int:
        """Always 1: ``n_estimators`` is passed as ``num_parallel_tree`` instead."""
        return 1

    def fit(
        self,
        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
        eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]] = None,
        verbose: Optional[Union[int, bool]] = True,
        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
        sample_weight_eval_set: Optional[Sequence[_DaskCollection]] = None,
        base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
        feature_weights: Optional[_DaskCollection] = None,
    ) -> "DaskXGBRFClassifier":
        # Check the early stopping / callback configuration before training, then
        # forward every argument to the parent `fit` by keyword.
        _check_rf_callback(self.early_stopping_rounds, self.callbacks)
        super().fit(
            X=X,
            y=y,
            sample_weight=sample_weight,
            base_margin=base_margin,
            eval_set=eval_set,
            verbose=verbose,
            xgb_model=xgb_model,
            sample_weight_eval_set=sample_weight_eval_set,
            base_margin_eval_set=base_margin_eval_set,
            feature_weights=feature_weights,
        )
        return self
