# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import math
import sys
from bisect import bisect_right, insort
from collections import ChainMap
from typing import Any, cast, Optional, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    StorageMeta,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _compare_save_plans,
    _contains_usable_plan,
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
    _merge_delta_local_plans,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.tensor import DTensor

from . import _version


# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)


# Names exported by a star-import of this module.
__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: Optional[bool] = None,
        dedup_save_to_lowest_rank: bool = False,
        enable_plan_caching: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.mappings = {}
        self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
        if dedup_replicated_tensors is not None:
            logger.warning(
                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
                "deprecated, and no longer has any effect. Please remove this argument "
                "from your call."
            )
        self._cached_plans_key: str = self.__class__.__name__
        self._enable_plan_caching = enable_plan_caching

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        storage_meta: Optional[StorageMeta] = None,
        is_coordinator: bool = False,
    ) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan

        if self._enable_plan_caching:
            # If plans are equal, we can skip sending the plan to the coordinator.
            if (
                self._cached_plans_key in SavePlanner._cached_save_plan
                and _compare_save_plans(
                    plan, SavePlanner._cached_save_plan[self._cached_plans_key]
                )
            ):
                logger.info(
                    "No change in the local plan. Skipping sending the plan to the coordinator"
                )
                return SavePlan([], usable=False)
            else:
                SavePlanner._cached_save_plan[self._cached_plans_key] = plan

        return self.plan

    def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
        return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)

    def _create_global_plan(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], Metadata]:
        deduped_plans = self._dedup_save_plans(all_plans)

        global_plan, metadata = create_default_global_save_plan(deduped_plans)

        if self.flatten_state_dict:
            # | does not work for Python 3.8 or older version.
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        return global_plan, metadata

    def _create_global_plan_with_caching(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
        """
        Create global plan with caching.
        Returns a tuple of global_plan_delta, global_plan, metadata.
        """
        global_plan_delta: list[SavePlan] = []

        if self._cached_plans_key not in SavePlanner._cached_all_plans:
            # Case 1: If the plans are not cached, the cache will be hydrated with the
            # all_plans, global_plans (Deduped), and metadata.

            # Cache the original all_plans
            SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
            global_plan, metadata = self._create_global_plan(all_plans)
            # Cache the deduped and validated global_plan
            SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
            # Cache the metadata
            SavePlanner._cached_metadata[self._cached_plans_key] = metadata
            # If plans are not cached, global_plan delta will be the same as global plan.
            return global_plan, global_plan, metadata

        # Case 2: Plans are cached
        if not _contains_usable_plan(all_plans):
            # Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans).
            # Global plan delta will be empty plans to avoid the collective overhead.
            # We can reuse the deduped global plan and metadata from the cache directly.
            global_plan_delta = [SavePlan([], usable=False)] * len(all_plans)
            global_plan = SavePlanner._cached_global_plan[self._cached_plans_key]
            metadata = SavePlanner._cached_metadata[self._cached_plans_key]
        else:
            # Case 2.2: Plans are cached but the local plans have changed.
            # We will merge the changed local plans with the cached local plans.
            # Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached.
            # Global plan delta will be created by comparing the new global plan with the cached global plan.
            # Only the global plan delta (updated ones) will be sent to the coordinator to avoid the collective overhead.
            merged_plans = _merge_delta_local_plans(
                SavePlanner._cached_all_plans[self._cached_plans_key], all_plans
            )
            # Cache the updated local plans
            SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
            global_plan, metadata = self._create_global_plan(merged_plans)

            if self._cached_plans_key in self._cached_global_plan:
                for cached_plan, new_plan in zip(
                    SavePlanner._cached_global_plan[self._cached_plans_key], global_plan
                ):
                    if _compare_save_plans(cached_plan, new_plan):
                        global_plan_delta.append(SavePlan([], usable=False))
                    else:
                        global_plan_delta.append(new_plan)

            # Cache the new global plan and the metadata
            SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
            SavePlanner._cached_metadata[self._cached_plans_key] = metadata

        return global_plan_delta, global_plan, metadata

    def create_global_plan(
        self, all_plans: list[SavePlan]
    ) -> tuple[list[SavePlan], Metadata]:
        global_plan_delta: list[SavePlan] = []
        if self._enable_plan_caching:
            # If the plans are cached, we only need to send the global plan delta to be scattered
            # across ranks. Ranks will use the cached final plans instead.
            (
                global_plan_delta,
                global_plan,
                metadata,
            ) = self._create_global_plan_with_caching(all_plans)
        else:
            global_plan, metadata = self._create_global_plan(all_plans)
            # If the caching is not enabled, global delta plan will always be same as the new global plan.
            global_plan_delta = global_plan

        self.global_plan = global_plan
        self.metadata = metadata

        return global_plan_delta, self.metadata

    def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan:
        finished_plan: SavePlan = new_plan

        if not new_plan.usable:
            finished_plan = SavePlanner._cached_final_save_plan[self._cached_plans_key]
        else:
            finished_plan = new_plan
            SavePlanner._cached_final_save_plan[self._cached_plans_key] = new_plan
        return finished_plan

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        finished_plan: SavePlan = new_plan

        if self._enable_plan_caching:
            finished_plan = self._finish_plan_with_caching(new_plan)

        self.plan = finished_plan
        return self.plan

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        object = self.lookup_object(write_item.index)
        return self.transform_object(write_item, object)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """Extension from the planner interface to make it easy to extend the default planner."""
        if write_item.type == WriteItemType.BYTE_IO:
            bytes = io.BytesIO()
            torch.save(object, bytes)
            object = bytes
        return object


class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

    In particular it adds the following:

    flatten_state_dict: Handle state_dict with nested dicts
    flatten_sharded_tensors: For FSDP in 2D parallel mode
    allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
    """

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        allow_partial_load: bool = False,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}
        self.allow_partial_load = allow_partial_load

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        _init_state_dict(state_dict)
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        if self.metadata is None:
            raise AssertionError("self.metadata is not None")
        if self.flatten_state_dict:
            # To support checkpoints that are saved before v2.4, we have to
            # differentiate if the missing keys are due to old checkpoints.
            # The contracts are:
            # 1. There are 3 cases when we found a missing key.
            #    1.1 Actual missing key, but allow_partial_load is False
            #    1.2 Actual missing key, but allow_partial load is True
            #    1.3 Old checkpoint, but allow_partial_load is False
            #    1.4 Old checkpoint, but allow_partial_load is True
            # 2. If we found a missing key, we first convert the keys back to
            #    the key format of v2.3
            # 3. If the previous missing keys are in the v2.3 keys, we assume
            #    this is a old checkpoint.
            # 4. Pass the state_dict to `create_default_local_load_plan()`,
            #    which has the logic to check missing for allow_partial_load.
            # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to
            # `create_default_local_load_plan()`. The logic here is to determine
            # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after).
            current_keys = set(self.state_dict.keys())
            load_keys = set(self.metadata.state_dict_metadata.keys())
            missing_keys = load_keys - current_keys
            if missing_keys:
                _version._derived_version = "2_3"
                old_state_dict, old_mappings = flatten_state_dict(
                    self.original_state_dict
                )
                old_keys = set(old_state_dict.keys())
                if old_keys & missing_keys:
                    self.state_dict, self.mappings = old_state_dict, old_mappings
                # _derived_version is only used by flatten_state_dict now.
                # Set it back to None so that later we can save to a new version.
                _version._derived_version = None

        return create_default_local_load_plan(
            self.state_dict, self.metadata, not self.allow_partial_load
        )

    def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[read_item.dest_index.fqn],
                torch.load(value, weights_only=False),
            )
        else:
            self.state_dict[read_item.dest_index.fqn] = torch.load(
                value, weights_only=False
            )

    def resolve_tensor(self, read_item: ReadItem):
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """Extension from the planner interface to make it easy to extend the default planner."""
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
    Useful for loading in state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    .. note::
        ``state_dict`` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, It's recommended to only utilize
        this LoadPlanner on a single rank or process to avoid OOM.

    Args:
        keys: Optional collection of keys to load. ``None`` loads every key found in
            the metadata. A key is loaded when it is listed itself or when any
            dotted prefix of its recorded path is listed.
    """

    def __init__(self, keys=None, *args, **kwargs):
        self.keys = keys
        super().__init__(*args, **kwargs)

    def _should_include_key(self, key: str, metadata: Metadata) -> bool:
        """
        Return True if ``key`` from the checkpoint metadata should be loaded.

        Keys that have no path recorded in ``metadata.planner_data`` can only be
        selected by an exact match against ``self.keys``.
        """
        if self.keys is None:
            return True

        if key in self.keys:
            return True

        # planner_data may be None or lack this key; set_up_planner tolerates
        # both cases, so do the same here instead of raising.
        planner_data = metadata.planner_data
        path = planner_data.get(key) if planner_data is not None else None
        if path is None:
            return False

        # Build every dotted prefix of the path, e.g. ("a", "b") -> ["a", "a.b"].
        unflattened_keys: list[str] = []
        for unflattened_key in path:
            if unflattened_keys:
                unflattened_keys.append(
                    ".".join([unflattened_keys[-1], str(unflattened_key)])
                )

            else:
                unflattened_keys.append(unflattened_key)

        return any(unflattened_key in self.keys for unflattened_key in unflattened_keys)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Optional[Metadata] = None,
        is_coordinator: bool = False,
    ) -> None:
        """
        Populate the empty ``state_dict`` from ``metadata``, then set up the base planner.

        Raises:
            AssertionError: if ``state_dict`` is not empty or ``metadata`` is None.
        """
        if state_dict:
            raise AssertionError("not state_dict")
        if metadata is None:
            raise AssertionError("metadata is not None")

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if not self._should_include_key(k, metadata):
                continue

            if isinstance(v, TensorStorageMetadata):
                # Allocate an uninitialized tensor with the saved size and dtype.
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if metadata.planner_data is not None and k in metadata.planner_data:
                # Restore the original nesting recorded at save time.
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


def create_default_local_load_plan(
    state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
) -> LoadPlan:
    requests = []
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.

    The default behavior is to match key exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """

    for fqn, obj in state_dict.items():
        # ignore state_dict keys which do not exist in `state_dict` if strict=False
        if fqn not in metadata.state_dict_metadata:
            if strict:
                raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
            else:
                continue

        md = metadata.state_dict_metadata[fqn]
        if (
            isinstance(md, TensorStorageMetadata)
            and getattr(obj, "size", None) is not None
            and md.size != obj.size()
        ):
            raise ValueError(
                f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
            )
        # Since DTensor supports submesh, adding extra check to ensure _create_read_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)


def create_default_global_load_plan(
    all_plans: list[LoadPlan],
) -> list[LoadPlan]:
    """
    Create the global load plan used by DefaultLoadPlanner.

    Default loading needs no coordination between ranks, so the local plans
    are handed back exactly as they were received.
    """
    # Nothing to rewrite: the very same list object is returned.
    return all_plans


def create_default_local_save_plan(
    state_dict: dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    On non-coordinator ranks, this function ignores tensors and non-tensor objects,
    only producing writes for ShardedTensor objects.

    On the coordinator rank, produce writes for all values.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submesh, adding extra check to ensure _create_write_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_write_items(fqn, obj)
        else:
            # For the plain tensor and non-tensor values, add the request for all
            # the ranks. Coordinator will decides whether to deduplicate the
            # values based on the keys.
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)


def create_default_global_save_plan(
    all_plans: list[SavePlan],
    rewrite_index_hints: bool = True,
) -> tuple[list[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    The metadata is assembled from every ``WriteItem`` in the supplied plans:
    ``BYTE_IO`` items become ``BytesStorageMetadata`` entries, and all other
    items contribute one chunk to the ``TensorStorageMetadata`` of their FQN.

    The plans themselves are only altered when ``rewrite_index_hints`` is True,
    in which case each tensor item's ``MetadataIndex.index`` is set to the
    position of its chunk in the metadata.

    Raises:
        AssertionError: if a non-``SHARD`` item repeats an FQN that is already
            present, or if a tensor item has no ``tensor_data`` or no chunk.
    """
    storage_md: dict[str, STORAGE_TYPES] = {}
    rewritten_plans = []
    for plan in all_plans:
        rewritten_items = []
        for item in plan.items:
            fqn = item.index.fqn
            # Only SHARD items may contribute several entries to the same FQN.
            if item.type != WriteItemType.SHARD and fqn in storage_md:
                raise AssertionError("item.index.fqn not in md")

            if item.type == WriteItemType.BYTE_IO:
                storage_md[fqn] = BytesStorageMetadata()
                rewritten_items.append(item)
                continue

            tensor_data = item.tensor_data
            if tensor_data is None:
                raise AssertionError("item.tensor_data is not None")

            # The first item seen for an FQN defines its properties and size.
            if fqn not in storage_md:
                storage_md[fqn] = TensorStorageMetadata(
                    properties=tensor_data.properties,
                    size=tensor_data.size,
                    chunks=[],
                )
            tensor_md = cast(TensorStorageMetadata, storage_md[fqn])

            if rewrite_index_hints:
                # The hint is where this item's chunk is about to be appended.
                hinted_index = dataclasses.replace(
                    item.index, index=len(tensor_md.chunks)
                )
                rewritten_items.append(dataclasses.replace(item, index=hinted_index))
            else:
                rewritten_items.append(item)

            if tensor_data.chunk is None:
                raise AssertionError(f"""
                    Cannot create MD for tensor without bounds.
                    FQN: {item.index.fqn}
                """)
            tensor_md.chunks.append(tensor_data.chunk)
        rewritten_plans.append(dataclasses.replace(plan, items=rewritten_items))
    return (rewritten_plans, Metadata(storage_md))


def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``."""
    # A single metadata-only plan is run through the global planning step;
    # only the resulting metadata is kept.
    metadata_only_plan = _create_default_metadata_only_plan(state_dict)
    return create_default_global_save_plan([metadata_only_plan])[1]


def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    """
    Return True if the two boxes overlap.

    Each box is described by its ``offsets`` and ``sizes``. The boxes are
    disjoint as soon as one dimension separates them, i.e. one box starts at
    or beyond the end of the other along that dimension (in 2D: one box lies
    entirely above or to the left of the other).
    """
    for dim, start0 in enumerate(box0.offsets):
        separated = (
            start0 >= box1.offsets[dim] + box1.sizes[dim]
            or box1.offsets[dim] >= start0 + box0.sizes[dim]
        )
        if separated:
            return False

    return True


def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    """
    Return True if ``inner_box`` fits inside a box of size ``outer_box_size``.

    Along every dimension of ``outer_box_size`` the inner offset and size must be
    non-negative and the inner box must end at or before the outer extent.
    """
    for dim, extent in enumerate(outer_box_size):
        start = inner_box.offsets[dim]
        if start < 0:
            return False
        length = inner_box.sizes[dim]
        if length < 0 or start + length > extent:
            return False

    return True


def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
    """
    Check the tensor chunks recorded in ``metadata`` for consistency.

    For every tensor entry with at least one dimension this verifies that:

    * each chunk lies within the tensor's bounds,
    * no two chunks overlap, and
    * when ``global_plan`` holds more than one plan, the chunk volumes add up
      to the tensor's volume.

    Every violation is logged as a warning. Returns True only if none was found.
    """
    is_valid = True
    for key, value in metadata.state_dict_metadata.items():
        # Byte blobs and zero-dimensional tensors have no chunk layout to check.
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue

        chunks = value.chunks
        chunks_volume = 0
        for chunk in chunks:
            # Compute the volume
            if not _check_box_bounds(value.size, chunk):
                logger.warning(
                    """
                        key:%s has out of bounds chunk:
                        tensor-size:%s chunk: %s
                    """,
                    key,
                    value.size,
                    chunk,
                )
                is_valid = False
            chunks_volume += math.prod(chunk.sizes)

        if len(chunks) > 1:
            ndim = len(value.size)
            # Sweep along the largest dimension of the tensor.
            sweep_dim = max(range(ndim), default=0, key=lambda d: value.size[d])

            def sweep_order(position: int) -> tuple:
                # Primary key: start along the sweep dimension; the full
                # offsets serve as a tie-breaker.
                offsets = chunks[position].offsets
                return (offsets[sweep_dim], *(offsets[d] for d in range(ndim)))

            visit_order = sorted(range(len(chunks)), key=sweep_order)

            # Chunks that may still intersect the current one along the sweep
            # dimension, kept sorted as (end, chunk position).
            open_intervals: list[tuple[int, int]] = []
            for position in visit_order:
                candidate = chunks[position]
                lo = candidate.offsets[sweep_dim]
                hi = lo + candidate.sizes[sweep_dim]

                # Drop every interval that ends at or before this chunk's start.
                expired = bisect_right(open_intervals, (lo, sys.maxsize))
                if expired:
                    del open_intervals[:expired]

                for _, open_position in open_intervals:
                    neighbor = chunks[open_position]
                    if _check_box_overlap(candidate, neighbor):
                        logger.warning(
                            "key:%s has overlapping chunks: %s %s",
                            key,
                            candidate,
                            neighbor,
                        )
                        is_valid = False

                insort(open_intervals, (hi, position))

        # Check whether combined chunk cover the whole tensor
        tensor_volume = math.prod(value.size)
        if len(global_plan) > 1 and chunks_volume != tensor_volume:
            logger.warning(
                """
                    key:%s invalid fill tensor-volume:
                    %s chunks-volume: %s
                """,
                key,
                tensor_volume,
                chunks_volume,
            )
            is_valid = False

    return is_valid
