"""
This module tests the functionality of StringArray and ArrowStringArray.
Tests for the str accessors are in pandas/tests/strings/test_string_array.py
"""

import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat.pyarrow import pa_version_under19p0

from pandas.core.dtypes.common import is_dtype_equal

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.string_arrow import (
    ArrowStringArray,
)


@pytest.fixture
def dtype(string_dtype_arguments):
    """Fixture building a StringDtype from the (storage, na_value) parametrization"""
    backend, missing = string_dtype_arguments
    return pd.StringDtype(storage=backend, na_value=missing)


@pytest.fixture
def dtype2(string_dtype_arguments2):
    """Fixture building a second StringDtype from 'string_dtype_arguments2'"""
    backend, missing = string_dtype_arguments2
    return pd.StringDtype(storage=backend, na_value=missing)


@pytest.fixture
def cls(dtype):
    """Fixture returning the array class constructed by the 'dtype' fixture"""
    array_type = dtype.construct_array_type()
    return array_type


def test_dtype_equality():
    """Compare StringDtype instances across storages and NA sentinels."""
    pytest.importorskip("pyarrow")

    python_na = pd.StringDtype("python")
    arrow_na = pd.StringDtype("pyarrow")
    arrow_nan = pd.StringDtype("pyarrow", na_value=np.nan)

    # each dtype compares equal to an explicitly spelled-out equivalent
    assert python_na == pd.StringDtype("python", na_value=pd.NA)
    assert arrow_na == pd.StringDtype("pyarrow", na_value=pd.NA)
    assert arrow_nan == pd.StringDtype("pyarrow", na_value=np.nan)
    assert arrow_nan == pd.StringDtype("pyarrow", na_value=float("nan"))

    # the three dtypes are pairwise different, in both comparison orders
    distinct = [python_na, arrow_na, arrow_nan]
    for i, left in enumerate(distinct):
        for j, right in enumerate(distinct):
            if i != j:
                assert left != right


def test_repr(dtype):
    df = pd.DataFrame({"A": pd.array(["a", pd.NA, "b"], dtype=dtype)})
    if dtype.na_value is np.nan:
        expected = "     A\n0    a\n1  NaN\n2    b"
    else:
        expected = "      A\n0     a\n1  <NA>\n2     b"
    assert repr(df) == expected

    if dtype.na_value is np.nan:
        expected = "0      a\n1    NaN\n2      b\nName: A, dtype: str"
    else:
        expected = "0       a\n1    <NA>\n2       b\nName: A, dtype: string"
    assert repr(df.A) == expected

    if dtype.storage == "pyarrow" and dtype.na_value is pd.NA:
        arr_name = "ArrowStringArray"
        expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
    elif dtype.storage == "pyarrow" and dtype.na_value is np.nan:
        arr_name = "ArrowStringArray"
        expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
    elif dtype.storage == "python" and dtype.na_value is np.nan:
        arr_name = "StringArray"
        expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
    else:
        arr_name = "StringArray"
        expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
    assert repr(df.A.array) == expected


def test_dtype_repr(dtype):
    if dtype.storage == "pyarrow":
        if dtype.na_value is pd.NA:
            assert repr(dtype) == "<StringDtype(na_value=<NA>)>"
        else:
            assert repr(dtype) == "<StringDtype(na_value=nan)>"
    elif dtype.na_value is pd.NA:
        assert repr(dtype) == "<StringDtype(storage='python', na_value=<NA>)>"
    else:
        assert repr(dtype) == "<StringDtype(storage='python', na_value=nan)>"


def test_none_to_nan(cls, dtype):
    a = cls._from_sequence(["a", None, "b"], dtype=dtype)
    assert a[1] is not None
    assert a[1] is a.dtype.na_value


def test_setitem_validates(cls, dtype):
    """Assigning non-string scalars or arrays must raise TypeError."""
    arr = cls._from_sequence(["a", "b"], dtype=dtype)

    # scalar assignment
    with pytest.raises(TypeError, match="Invalid value '10' for dtype 'str"):
        arr[0] = 10

    # array assignment through a full slice
    with pytest.raises(TypeError, match="Invalid value for dtype 'str"):
        arr[:] = np.array([1, 2])


def test_setitem_with_scalar_string(dtype):
    # is_float_dtype considers some strings, like 'd', to be floats
    # which can cause issues.
    arr = pd.array(["a", "c"], dtype=dtype)
    arr[0] = "d"
    expected = pd.array(["d", "c"], dtype=dtype)
    tm.assert_extension_array_equal(arr, expected)


def test_setitem_with_array_with_missing(dtype):
    # ensure that when setting with an array of values, we don't mutate the
    # array `value` in __setitem__(self, key, value)
    arr = pd.array(["a", "b", "c"], dtype=dtype)
    value = np.array(["A", None])
    value_orig = value.copy()
    arr[[0, 1]] = value

    expected = pd.array(["A", pd.NA, "c"], dtype=dtype)
    tm.assert_extension_array_equal(arr, expected)
    tm.assert_numpy_array_equal(value, value_orig)


def test_astype_roundtrip(dtype):
    ser = pd.Series(pd.date_range("2000", periods=12, unit="ns"))
    ser[0] = None

    casted = ser.astype(dtype)
    assert is_dtype_equal(casted.dtype, dtype)

    result = casted.astype("datetime64[ns]")
    tm.assert_series_equal(result, ser)

    # GH#38509 same thing for timedelta64
    ser2 = ser - ser.iloc[-1]
    casted2 = ser2.astype(dtype)
    assert is_dtype_equal(casted2.dtype, dtype)

    result2 = casted2.astype(ser2.dtype)
    tm.assert_series_equal(result2, ser2)


def test_constructor_raises(cls):
    """Invalid inputs to the array constructor must raise ValueError."""
    is_string_array = cls is pd.arrays.StringArray
    if is_string_array:
        msg = "StringArray requires a sequence of strings or pandas.NA"
        kwargs = {"dtype": pd.StringDtype()}
    else:
        msg = "Unsupported type '<class 'numpy.ndarray'>' for ArrowExtensionArray"
        kwargs = {}

    def assert_rejected(values):
        with pytest.raises(ValueError, match=msg):
            cls(values, **kwargs)

    assert_rejected(np.array(["a", "b"], dtype="S1"))
    assert_rejected(np.array([]))

    for na_like in (np.nan, None):
        values = np.array(["a", na_like], dtype=object)
        if is_string_array:
            # GH#45057 np.nan and None do NOT raise, as they are considered
            #  valid NAs for string dtype
            cls(values, **kwargs)
        else:
            assert_rejected(values)

    nat_likes = (pd.NaT, np.datetime64("NaT", "ns"), np.timedelta64("NaT", "ns"))
    for nat in nat_likes:
        assert_rejected(np.array(["a", nat], dtype=object))


@pytest.mark.parametrize("na", [np.nan, np.float64("nan"), float("nan"), None, pd.NA])
def test_constructor_nan_like(na):
    """Each NA-like scalar must give the same array as pd.NA does."""
    string_dtype = pd.StringDtype()
    with_na_like = np.array(["a", na], dtype="object")

    result = pd.arrays.StringArray(with_na_like, dtype=string_dtype)
    expected = pd.arrays.StringArray(np.array(["a", pd.NA]), dtype=string_dtype)
    tm.assert_extension_array_equal(result, expected)


@pytest.mark.parametrize("copy", [True, False])
def test_from_sequence_no_mutate(copy, cls, dtype):
    """_from_sequence must not modify the object array it is given."""
    source = np.array(["a", np.nan], dtype=object)
    untouched = source.copy()
    with_na = np.array(["a", pd.NA], dtype=object)

    result = cls._from_sequence(source, dtype=dtype, copy=copy)

    if cls is ArrowStringArray:
        import pyarrow as pa

        backing = pa.array(with_na, type=pa.string(), from_pandas=True)
        expected = cls(backing, dtype=dtype)
    else:
        raw = source if dtype.na_value is np.nan else with_na
        expected = cls(raw, dtype=dtype)

    tm.assert_extension_array_equal(result, expected)
    tm.assert_numpy_array_equal(source, untouched)


def test_astype_int(dtype):
    """Cast strings to int64, with and without a missing value."""
    complete = pd.array(["1", "2", "3"], dtype=dtype)
    tm.assert_numpy_array_equal(
        complete.astype("int64"), np.array([1, 2, 3], dtype="int64")
    )

    # with a missing value the cast must raise; the error depends on the sentinel
    with_missing = pd.array(["1", pd.NA, "3"], dtype=dtype)
    if dtype.na_value is np.nan:
        err, msg = ValueError, "cannot convert float NaN to integer"
    else:
        err = TypeError
        msg = (
            r"int\(\) argument must be a string, a bytes-like "
            r"object or a( real)? number"
        )
    with pytest.raises(err, match=msg):
        with_missing.astype("int64")


def test_astype_nullable_int(dtype):
    arr = pd.array(["1", pd.NA, "3"], dtype=dtype)

    result = arr.astype("Int64")
    expected = pd.array([1, pd.NA, 3], dtype="Int64")
    tm.assert_extension_array_equal(result, expected)


def test_astype_float(dtype, any_float_dtype):
    # Don't compare arrays (37974)
    ser = pd.Series(["1.1", pd.NA, "3.3"], dtype=dtype)
    result = ser.astype(any_float_dtype)
    item = np.nan if isinstance(result.dtype, np.dtype) else pd.NA
    expected = pd.Series([1.1, item, 3.3], dtype=any_float_dtype)
    tm.assert_series_equal(result, expected)


def test_reduce(skipna, dtype):
    arr = pd.Series(["a", "b", "c"], dtype=dtype)
    result = arr.sum(skipna=skipna)
    assert result == "abc"


def test_reduce_missing(skipna, dtype):
    arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
    result = arr.sum(skipna=skipna)
    if skipna:
        assert result == "abc"
    else:
        assert pd.isna(result)


@pytest.mark.parametrize("method", ["min", "max"])
def test_min_max(method, skipna, dtype):
    """min/max skip the missing value with skipna, else return the NA value."""
    ser = pd.Series(["a", "b", "c", None], dtype=dtype)
    reducer = getattr(ser, method)
    result = reducer(skipna=skipna)

    if not skipna:
        assert result is ser.dtype.na_value
        return

    if method == "min":
        expected = "a"
    else:
        expected = "c"
    assert result == expected


@pytest.mark.parametrize("method", ["min", "max"])
@pytest.mark.parametrize("box", [pd.Series, pd.array])
def test_min_max_numpy(method, box, dtype, request):
    """
    Check np.min / np.max on string data boxed in a Series or an array.

    The pyarrow-backed ``pd.array`` case is marked as an expected failure
    raising TypeError.
    """
    if dtype.storage == "pyarrow" and box is pd.array:
        # The condition above already pins ``box`` to ``pd.array``, so this is
        # the only xfail reason that can apply.
        reason = "'<=' not supported between instances of 'str' and 'NoneType'"
        mark = pytest.mark.xfail(raises=TypeError, reason=reason)
        request.applymarker(mark)

    arr = box(["a", "b", "c", None], dtype=dtype)
    result = getattr(np, method)(arr)
    expected = "a" if method == "min" else "c"
    assert result == expected


def test_fillna_args(dtype):
    # GH 37987
    arr = pd.array(["a", pd.NA], dtype=dtype)
    expected = pd.array(["a", "b"], dtype=dtype)

    # both a Python str and a numpy str scalar are accepted as fill value
    for fill in ("b", np.str_("b")):
        tm.assert_extension_array_equal(arr.fillna(value=fill), expected)

    with pytest.raises(TypeError, match="Invalid value '1' for dtype 'str"):
        arr.fillna(value=1)


def test_arrow_array(dtype):
    # protocol added in 0.15.0
    pa = pytest.importorskip("pyarrow")
    import pyarrow.compute as pc

    data = pd.array(["a", "b", "c"], dtype=dtype)
    converted = pa.array(data)

    # python storage is expected as string, otherwise large_string
    expected = pa.array(list(data), type=pa.large_string(), from_pandas=True)
    if dtype.storage == "python":
        expected = pc.cast(expected, pa.string())
    assert converted.equals(expected)


@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
def test_arrow_roundtrip(dtype, string_storage, using_infer_string):
    # roundtrip possible from arrow 1.0.0
    pa = pytest.importorskip("pyarrow")

    df = pd.DataFrame({"a": pd.array(["a", "b", None], dtype=dtype)})
    table = pa.table(df)
    arrow_type = "string" if dtype.storage == "python" else "large_string"
    assert table.field("a").type == arrow_type

    with pd.option_context("string_storage", string_storage):
        result = table.to_pandas()

    if dtype.na_value is np.nan and not using_infer_string:
        assert result["a"].dtype == "object"
        return

    assert isinstance(result["a"].dtype, pd.StringDtype)
    expected = df.astype(pd.StringDtype(string_storage, na_value=dtype.na_value))
    if using_infer_string:
        expected.columns = expected.columns.astype(
            pd.StringDtype(string_storage, na_value=np.nan)
        )
    tm.assert_frame_equal(result, expected)
    # ensure the missing value is the dtype's own NA sentinel
    assert result.loc[2, "a"] is result["a"].dtype.na_value


@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
def test_arrow_from_string(using_infer_string):
    # not roundtrip,  but starting with pyarrow table without pandas metadata
    pa = pytest.importorskip("pyarrow")
    column = pa.array(["a", "b", None], type=pa.string())
    result = pa.table({"a": column}).to_pandas()

    if using_infer_string and not pa_version_under19p0:
        expected_dtype = "str"
    else:
        expected_dtype = "object"
    expected = pd.DataFrame({"a": ["a", "b", None]}, dtype=expected_dtype)
    tm.assert_frame_equal(result, expected)


@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
def test_arrow_load_from_zero_chunks(dtype, string_storage, using_infer_string):
    # GH-41040
    pa = pytest.importorskip("pyarrow")

    df = pd.DataFrame({"a": pd.array([], dtype=dtype)})
    table = pa.table(df)
    arrow_type = "string" if dtype.storage == "python" else "large_string"
    assert table.field("a").type == arrow_type

    # Instantiate the same table with no chunks at all
    no_chunks = pa.chunked_array([], type=pa.string())
    table = pa.table([no_chunks], schema=table.schema)
    with pd.option_context("string_storage", string_storage):
        result = table.to_pandas()

    if dtype.na_value is np.nan and not using_string_dtype():
        assert result["a"].dtype == "object"
        return

    assert isinstance(result["a"].dtype, pd.StringDtype)
    expected = df.astype(pd.StringDtype(string_storage, na_value=dtype.na_value))
    if using_infer_string:
        expected.columns = expected.columns.astype(
            pd.StringDtype(string_storage, na_value=np.nan)
        )
    tm.assert_frame_equal(result, expected)


def test_value_counts_na(dtype):
    if dtype.na_value is np.nan:
        exp_dtype = "int64"
    elif dtype.storage == "pyarrow":
        exp_dtype = "int64[pyarrow]"
    else:
        exp_dtype = "Int64"
    arr = pd.array(["a", "b", "a", pd.NA], dtype=dtype)
    result = arr.value_counts(dropna=False)
    expected = pd.Series([2, 1, 1], index=arr[[0, 1, 3]], dtype=exp_dtype, name="count")
    tm.assert_series_equal(result, expected)

    result = arr.value_counts(dropna=True)
    expected = pd.Series([2, 1], index=arr[:2], dtype=exp_dtype, name="count")
    tm.assert_series_equal(result, expected)


def test_value_counts_with_normalize(dtype):
    if dtype.na_value is np.nan:
        exp_dtype = np.float64
    elif dtype.storage == "pyarrow":
        exp_dtype = "double[pyarrow]"
    else:
        exp_dtype = "Float64"
    ser = pd.Series(["a", "b", "a", pd.NA], dtype=dtype)
    result = ser.value_counts(normalize=True)
    expected = pd.Series([2, 1], index=ser[:2], dtype=exp_dtype, name="proportion") / 3
    tm.assert_series_equal(result, expected)


def test_value_counts_sort_false(dtype):
    if dtype.na_value is np.nan:
        exp_dtype = "int64"
    elif dtype.storage == "pyarrow":
        exp_dtype = "int64[pyarrow]"
    else:
        exp_dtype = "Int64"
    ser = pd.Series(["a", "b", "c", "b"], dtype=dtype)
    result = ser.value_counts(sort=False)
    expected = pd.Series([1, 2, 1], index=ser[:3], dtype=exp_dtype, name="count")
    tm.assert_series_equal(result, expected)


def test_memory_usage(dtype):
    # GH 33963

    if dtype.storage == "pyarrow":
        pytest.skip(f"not applicable for {dtype.storage}")

    series = pd.Series(["a", "b", "c"], dtype=dtype)

    assert 0 < series.nbytes <= series.memory_usage() < series.memory_usage(deep=True)


@pytest.mark.parametrize("float_dtype", [np.float16, np.float32, np.float64])
def test_astype_from_float_dtype(float_dtype, dtype):
    # https://github.com/pandas-dev/pandas/issues/36451
    result = pd.Series([0.1], dtype=float_dtype).astype(dtype)
    tm.assert_series_equal(result, pd.Series(["0.1"], dtype=dtype))


def test_to_numpy_returns_pdna_default(dtype):
    arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
    result = np.array(arr)
    expected = np.array(["a", dtype.na_value, "b"], dtype=object)
    tm.assert_numpy_array_equal(result, expected)


def test_to_numpy_na_value(dtype, nulls_fixture):
    na_value = nulls_fixture
    arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
    result = arr.to_numpy(na_value=na_value)
    expected = np.array(["a", na_value, "b"], dtype=object)
    tm.assert_numpy_array_equal(result, expected)


def test_to_numpy_readonly(dtype):
    arr = pd.array(["a", pd.NA, "b"], dtype=dtype)
    arr._readonly = True
    result = arr.to_numpy()
    if dtype.storage == "python":
        assert not result.flags.writeable
    else:
        assert result.flags.writeable


def test_isin(dtype, fixed_now_ts):
    s = pd.Series(["a", "b", None], dtype=dtype)

    result = s.isin(["a", "c"])
    expected = pd.Series([True, False, False])
    tm.assert_series_equal(result, expected)

    result = s.isin(["a", pd.NA])
    expected = pd.Series([True, False, True])
    tm.assert_series_equal(result, expected)

    result = s.isin([])
    expected = pd.Series([False, False, False])
    tm.assert_series_equal(result, expected)

    result = s.isin(["a", fixed_now_ts])
    expected = pd.Series([True, False, False])
    tm.assert_series_equal(result, expected)

    result = s.isin([fixed_now_ts])
    expected = pd.Series([False, False, False])
    tm.assert_series_equal(result, expected)


def test_isin_string_array(dtype, dtype2):
    s = pd.Series(["a", "b", None], dtype=dtype)

    result = s.isin(pd.array(["a", "c"], dtype=dtype2))
    expected = pd.Series([True, False, False])
    tm.assert_series_equal(result, expected)

    result = s.isin(pd.array(["a", None], dtype=dtype2))
    expected = pd.Series([True, False, True])
    tm.assert_series_equal(result, expected)


def test_isin_arrow_string_array(dtype):
    """Series.isin where the candidates are an ArrowDtype(string) array."""
    pa = pytest.importorskip("pyarrow")
    ser = pd.Series(["a", "b", None], dtype=dtype)
    arrow_dtype = pd.ArrowDtype(pa.string())

    cases = (
        (["a", "c"], [True, False, False]),
        (["a", None], [True, False, True]),
    )
    for values, mask in cases:
        result = ser.isin(pd.array(values, dtype=arrow_dtype))
        tm.assert_series_equal(result, pd.Series(mask))


def test_setitem_scalar_with_mask_validation(dtype):
    # https://github.com/pandas-dev/pandas/issues/47628
    # setting None with a boolean mask (through _putmask) should still result
    # in the dtype's NA value in the underlying array
    mask = np.array([False, True, False])

    ser = pd.Series(["a", "b", "c"], dtype=dtype)
    ser[mask] = None
    assert ser.array[1] is ser.dtype.na_value

    # setting any other non-string scalar through the mask should raise
    fresh = pd.Series(["a", "b", "c"], dtype=dtype)
    with pytest.raises(TypeError, match="Invalid value '1' for dtype 'str"):
        fresh[mask] = 1


def test_from_numpy_str(dtype):
    vals = ["a", "b", "c"]
    arr = np.array(vals, dtype=np.str_)
    result = pd.array(arr, dtype=dtype)
    expected = pd.array(vals, dtype=dtype)
    tm.assert_extension_array_equal(result, expected)


def test_tolist(dtype):
    vals = ["a", "b", "c"]
    arr = pd.array(vals, dtype=dtype)
    result = arr.tolist()
    expected = vals
    tm.assert_equal(result, expected)


def test_string_array_view_type_error():
    """Viewing a string array as another data type must raise TypeError."""
    strings = pd.array(["a", "b", "c"], dtype="string")
    msg = "Cannot change data-type for string array."
    with pytest.raises(TypeError, match=msg):
        strings.view("i8")


@pytest.mark.parametrize("box", [pd.Series, pd.array])
def test_numpy_array_ufunc(dtype, box):
    """Apply custom numpy ufuncs to string data in a Series or an array."""
    arr = box(["a", "bb", "ccc"], dtype=dtype)

    # custom ufunc that works with string (object) input -> returning numeric
    str_len_ufunc = np.frompyfunc(len, 1, 1)
    lengths = str_len_ufunc(arr)
    make_expected = pd.Series if box is pd.Series else np.array
    # TODO we should infer int64 dtype here?
    tm.assert_equal(lengths, make_expected([1, 2, 3], dtype=object))

    # custom ufunc returning strings
    str_multiply_ufunc = np.frompyfunc(lambda x: x * 2, 1, 1)
    doubled = str_multiply_ufunc(arr)
    if dtype.storage != "pyarrow":
        expected = box(["aa", "bbbb", "cccccc"], dtype=dtype)
    elif box is pd.array:
        # TODO ArrowStringArray should also preserve the class / dtype
        expected = np.array(["aa", "bbbb", "cccccc"], dtype=object)
    else:
        # not specifying the dtype because the exact dtype is not yet preserved
        expected = pd.Series(["aa", "bbbb", "cccccc"])

    tm.assert_equal(doubled, expected)
