from __future__ import annotations

from functools import partial
from typing import Callable, Final

import mypy.errorcodes as codes
from mypy import message_registry
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
from mypy.plugin import (
    AttributeContext,
    ClassDefContext,
    FunctionContext,
    FunctionSigContext,
    MethodContext,
    MethodSigContext,
    Plugin,
)
from mypy.plugins.attrs import (
    attr_class_maker_callback,
    attr_class_makers,
    attr_dataclass_makers,
    attr_define_makers,
    attr_frozen_makers,
    attr_tag_callback,
    evolve_function_sig_callback,
    fields_function_sig_callback,
)
from mypy.plugins.common import try_getting_str_literals
from mypy.plugins.constants import (
    ENUM_NAME_ACCESS,
    ENUM_VALUE_ACCESS,
    SINGLEDISPATCH_CALLABLE_CALL_METHOD,
    SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD,
    SINGLEDISPATCH_REGISTER_METHOD,
)
from mypy.plugins.ctypes import (
    array_constructor_callback,
    array_getitem_callback,
    array_iter_callback,
    array_raw_callback,
    array_setitem_callback,
    array_value_callback,
)
from mypy.plugins.dataclasses import (
    dataclass_class_maker_callback,
    dataclass_makers,
    dataclass_tag_callback,
    replace_function_sig_callback,
)
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
from mypy.plugins.functools import (
    functools_total_ordering_maker_callback,
    functools_total_ordering_makers,
    partial_call_callback,
    partial_new_callback,
)
from mypy.plugins.singledispatch import (
    call_singledispatch_function_after_register_argument,
    call_singledispatch_function_callback,
    create_singledispatch_function_callback,
    singledispatch_register_callback,
)
from mypy.subtypes import is_subtype
from mypy.typeops import is_literal_type_like, make_simplified_union
from mypy.types import (
    TPDICT_FB_NAMES,
    AnyType,
    CallableType,
    FunctionLike,
    Instance,
    LiteralType,
    NoneType,
    TupleType,
    Type,
    TypedDictType,
    TypeOfAny,
    TypeVarType,
    UnionType,
    get_proper_type,
    get_proper_types,
)

# Fully-qualified method names, one per entry of TPDICT_FB_NAMES (presumably
# the TypedDict fallback class names), that the hooks below special-case.
TD_SETDEFAULT_NAMES: Final = {f"{fb_name}.setdefault" for fb_name in TPDICT_FB_NAMES}
TD_POP_NAMES: Final = {f"{fb_name}.pop" for fb_name in TPDICT_FB_NAMES}
TD_DELITEM_NAMES: Final = {f"{fb_name}.__delitem__" for fb_name in TPDICT_FB_NAMES}

# Methods that merge another mapping into a TypedDict; they all share
# typed_dict_update_signature_callback.
TD_UPDATE_METHOD_NAMES: Final = {
    f"{fb_name}.{method}"
    for fb_name in TPDICT_FB_NAMES
    for method in ("update", "__or__", "__ror__", "__ior__")
}


class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default.

    Each ``get_*_hook`` method receives a fully qualified name. It returns the
    callback registered for that name, or ``None`` when this plugin has nothing
    to add for it.
    """

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        # Return-type callbacks for calls of these functions/classes.
        if fullname == "_ctypes.Array":
            return array_constructor_callback
        if fullname == "functools.singledispatch":
            return create_singledispatch_function_callback
        if fullname == "functools.partial":
            return partial_new_callback
        if fullname == "enum.member":
            return enum_member_callback
        return None

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        # Signature callbacks for attrs/dataclasses helper functions.
        if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
            return evolve_function_sig_callback
        if fullname in ("attr.fields", "attrs.fields"):
            return fields_function_sig_callback
        if fullname == "dataclasses.replace":
            return replace_function_sig_callback
        return None

    def get_method_signature_hook(
        self, fullname: str
    ) -> Callable[[MethodSigContext], FunctionLike] | None:
        # Signature callbacks for TypedDict, ctypes array and singledispatch methods.
        if fullname == "typing.Mapping.get":
            return typed_dict_get_signature_callback
        if fullname in TD_SETDEFAULT_NAMES:
            return typed_dict_setdefault_signature_callback
        if fullname in TD_POP_NAMES:
            return typed_dict_pop_signature_callback
        if fullname == "_ctypes.Array.__setitem__":
            return array_setitem_callback
        if fullname == SINGLEDISPATCH_CALLABLE_CALL_METHOD:
            return call_singledispatch_function_callback
        if fullname in TD_UPDATE_METHOD_NAMES:
            return typed_dict_update_signature_callback
        return None

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        # Return-type callbacks for method calls.
        if fullname == "typing.Mapping.get":
            return typed_dict_get_callback
        if fullname == "builtins.int.__pow__":
            return int_pow_callback
        if fullname == "builtins.int.__neg__":
            return int_neg_callback
        if fullname == "builtins.int.__pos__":
            return int_pos_callback
        if fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
            return tuple_mul_callback
        if fullname in TD_SETDEFAULT_NAMES:
            return typed_dict_setdefault_callback
        if fullname in TD_POP_NAMES:
            return typed_dict_pop_callback
        if fullname in TD_DELITEM_NAMES:
            return typed_dict_delitem_callback
        if fullname == "_ctypes.Array.__getitem__":
            return array_getitem_callback
        if fullname == "_ctypes.Array.__iter__":
            return array_iter_callback
        if fullname == SINGLEDISPATCH_REGISTER_METHOD:
            return singledispatch_register_callback
        if fullname == SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD:
            return call_singledispatch_function_after_register_argument
        if fullname == "functools.partial.__call__":
            return partial_call_callback
        return None

    def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
        # Attribute-access callbacks for ctypes arrays and enum name/value access.
        if fullname == "_ctypes.Array.value":
            return array_value_callback
        if fullname == "_ctypes.Array.raw":
            return array_raw_callback
        if fullname in ENUM_NAME_ACCESS:
            return enum_name_callback
        if fullname in ENUM_VALUE_ACCESS:
            return enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        # These hooks run in the main semantic analysis pass. They only tag
        # known dataclasses/attrs classes, so that the second-stage hooks
        # (get_class_decorator_hook_2) can find such classes in the MRO.
        if fullname in dataclass_makers:
            return dataclass_tag_callback
        attrs_maker_groups = (
            attr_class_makers,
            attr_dataclass_makers,
            attr_frozen_makers,
            attr_define_makers,
        )
        if any(fullname in group for group in attrs_maker_groups):
            return attr_tag_callback
        return None

    def get_class_decorator_hook_2(
        self, fullname: str
    ) -> Callable[[ClassDefContext], bool] | None:
        # Second-stage class decorator processing. The attrs variants differ
        # only in the defaults bound via functools.partial.
        if fullname in dataclass_makers:
            return dataclass_class_maker_callback
        if fullname in functools_total_ordering_makers:
            return functools_total_ordering_maker_callback
        if fullname in attr_class_makers:
            return attr_class_maker_callback
        if fullname in attr_dataclass_makers:
            return partial(attr_class_maker_callback, auto_attribs_default=True)
        if fullname in attr_frozen_makers:
            return partial(
                attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
            )
        if fullname in attr_define_makers:
            return partial(
                attr_class_maker_callback, auto_attribs_default=None, slots_default=True
            )
        return None


def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of ``TypedDict.get(key, default)`` for a literal key.

    When the key is a string literal declared on the TypedDict, the default
    parameter is retyped as ``<value type> | T``, where T is the signature's
    single type variable. The value type only provides type context for
    inferring the default argument; because T is still in the union, any
    default is accepted. If the value type is itself a TypedDict and the
    default is a literal ``{}``, that value type is first made to have no
    required keys.

    In every other case the default signature is returned unchanged.
    """
    sig = ctx.default_signature
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(sig.arg_types) == 2
        and len(sig.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        return sig

    looked_up = get_proper_type(ctx.type.items.get(ctx.args[0][0].value))
    if not looked_up:
        # Unknown key: nothing to refine.
        return sig

    default_expr = ctx.args[1][0]
    if (
        isinstance(looked_up, TypedDictType)
        and isinstance(default_expr, DictExpr)
        and not default_expr.items
    ):
        # `{}` as the default for a TypedDict-valued key.
        looked_up = looked_up.copy_modified(required_keys=set())

    type_var = sig.variables[0]
    assert isinstance(type_var, TypeVarType)
    widened_default = make_simplified_union([looked_up, type_var])
    return sig.copy_modified(
        arg_types=[sig.arg_types[0], widened_default],
        ret_type=sig.ret_type,
    )


def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument.

    For ``td.get(key)``, the result is the union of the value types of the
    literal key(s) plus ``None``. For ``td.get(key, default)``, it is the union
    of those value types and the default's type. As a special case, a literal
    ``{}`` default for a TypedDict-valued key contributes that TypedDict with no
    required keys instead.

    Returns ``ctx.default_return_type`` when:
    - the receiver is not a TypedDict,
    - the key is not a (union of) string literal(s),
    - a key is not declared, or
    - the call shape is neither of the two above.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        # Decide the call shape once, before looking at any key. Any other
        # shape must not reach the loop: nothing would be collected, and the
        # empty union would become `Never`, making later code look unreachable.
        if len(ctx.arg_types) == 1:
            has_default = False
        elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
            has_default = True
        else:
            return ctx.default_return_type

        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            return ctx.default_return_type

        output_types: list[Type] = []
        for key in keys:
            value_type = get_proper_type(ctx.type.items.get(key))
            if value_type is None:
                return ctx.default_return_type

            if not has_default:
                output_types.append(value_type)
                continue

            default_arg = ctx.args[1][0]
            if (
                isinstance(default_arg, DictExpr)
                and len(default_arg.items) == 0
                and isinstance(value_type, TypedDictType)
            ):
                # Special case '{}' as the default for a typed dict type.
                output_types.append(value_type.copy_modified(required_keys=set()))
            else:
                output_types.append(value_type)
                output_types.append(ctx.arg_types[1][0])

        if not has_default:
            # Without a default, a missing key makes `get` return None.
            output_types.append(NoneType())

        return make_simplified_union(output_types)
    return ctx.default_return_type


def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of ``TypedDict.pop`` for a literal string key.

    The key parameter is always typed as ``str``. For ``td.pop("key", default)``
    with a declared key, both the default parameter and the return type become
    ``<value type> | T``, where T is the signature's single type variable. The
    value type gives the default expression inference context; T keeps any
    default acceptable.
    """
    sig = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])

    declared: Type | None = None
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(sig.arg_types) == 2
        and len(sig.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        declared = ctx.type.items.get(ctx.args[0][0].value)

    if not declared:
        # Not a literal-key call with a default, or the key is unknown:
        # only tighten the key parameter.
        return sig.copy_modified(arg_types=[str_type, sig.arg_types[1]])

    type_var = sig.variables[0]
    assert isinstance(type_var, TypeVarType)
    value_or_default = make_simplified_union([declared, type_var])
    return sig.copy_modified(arg_types=[str_type, value_or_default], ret_type=value_or_default)


def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The key must be a string literal (or a union of them). For each key:
    - if it is required or read-only, a "cannot be deleted" error is reported
      (inference continues);
    - if it is not declared, a "key not found" error is reported and
      ``AnyType`` is returned.

    Returns:
        The union of the popped value types, plus the default argument's type
        when one is passed. ``AnyType`` (from_error) when the key is not a
        literal or is unknown. ``ctx.default_return_type`` for unrecognised
        call shapes.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)

        value_types: list[Type] = []
        for key in keys:
            if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, key_expr)

            value_type = ctx.type.items.get(key)
            if value_type:
                value_types.append(value_type)
            else:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
                return AnyType(TypeOfAny.from_error)

        # The guard above only guarantees one argument slot, so the default
        # slot may be missing entirely; don't index ctx.args[1] blindly.
        if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
            return make_simplified_union(value_types)
        elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type


def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Refine the signature of ``TypedDict.setdefault`` for a literal string key.

    The key parameter is always typed as ``str``. For
    ``td.setdefault("key", default)`` with a declared key, the default
    parameter is typed as that key's value type. This gives the default
    expression the right type context; otherwise the default parameter keeps
    its declared type.
    """
    sig = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])
    default_param_type = sig.arg_types[1]

    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(sig.arg_types) == 2
        and len(ctx.args[1]) == 1
    ):
        declared = ctx.type.items.get(ctx.args[0][0].value)
        if declared:
            default_param_type = declared

    return sig.copy_modified(arg_types=[str_type, default_param_type])


def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check ``TypedDict.setdefault`` and infer a precise return type.

    The key must be a string literal (or a union of them).
    - Read-only keys are reported as mutated; type checking continues.
    - An undeclared key is reported and yields ``AnyType`` (from_error).
    - A default not assignable to every targeted value type is reported and
      yields ``AnyType`` (from_error).

    Otherwise the result is the union of the targeted value types.
    """
    if not (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 2
        and len(ctx.arg_types[0]) == 1
        and len(ctx.arg_types[1]) == 1
    ):
        return ctx.default_return_type

    key_expr = ctx.args[0][0]
    keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            key_expr,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    touched_readonly = ctx.type.readonly_keys & set(keys)
    if touched_readonly:
        ctx.api.msg.readonly_keys_mutated(touched_readonly, context=key_expr)

    default_type = ctx.arg_types[1][0]
    default_expr = ctx.args[1][0]

    collected: list[Type] = []
    for key in keys:
        declared = ctx.type.items.get(key)
        if declared is None:
            ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
            return AnyType(TypeOfAny.from_error)

        # The signature hook can't always pick the right parameter type (e.g.
        # the key is a variable whose type is a str Literal). So re-check here
        # that the default fits every key being set.
        if not is_subtype(default_type, declared):
            ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                default_type, declared, default_expr
            )
            return AnyType(TypeOfAny.from_error)

        collected.append(declared)

    return make_simplified_union(collected)


def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check ``del td[key]`` on a TypedDict.

    The key must be a string literal (or a union of them); otherwise an error
    is reported and ``AnyType`` (from_error) is returned. Each required or
    read-only key is reported as not deletable, and each undeclared key as not
    found. In those cases the default return type is still returned.
    """
    td = ctx.type
    if not (
        isinstance(td, TypedDictType)
        and len(ctx.arg_types) == 1
        and len(ctx.arg_types[0]) == 1
    ):
        return ctx.default_return_type

    key_expr = ctx.args[0][0]
    keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(
            message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
            key_expr,
            code=codes.LITERAL_REQ,
        )
        return AnyType(TypeOfAny.from_error)

    undeletable = td.required_keys | td.readonly_keys
    for key in keys:
        if key in undeletable:
            ctx.api.msg.typeddict_key_cannot_be_deleted(td, key, key_expr)
        elif key not in td.items:
            ctx.api.msg.typeddict_key_not_found(td, key, key_expr)
    return ctx.default_return_type


_TP_DICT_MUTATING_METHODS: Final = frozenset({"update of TypedDict", "__ior__ of TypedDict"})


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for methods that update `TypedDict`.

    This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`,
    and `TypedDict.__ior__`.

    The single parameter's declared TypedDict type is replaced by an anonymous
    copy in which no key is required, so partial updates are accepted.

    When the call passes an argument:
    - its type is inferred with that copy as type context;
    - for every TypedDict found in the inferred type (directly or as a union
      member), a variant of the parameter type is built in which the keys
      required by that TypedDict (limited to keys the parameter declares)
      are required again;
    - the parameter becomes the union of those variants.

    Returns the default signature unchanged when the receiver is not a
    TypedDict, the signature does not have exactly one parameter, or that
    parameter is not a TypedDict.
    """
    signature = ctx.default_signature
    if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
        arg_type = get_proper_type(signature.arg_types[0])
        if not isinstance(arg_type, TypedDictType):
            return signature
        # Anonymous copy with no required keys: any subset of keys may be passed.
        arg_type = arg_type.as_anonymous()
        arg_type = arg_type.copy_modified(required_keys=set())
        if ctx.args and ctx.args[0]:
            if signature.name in _TP_DICT_MUTATING_METHODS:
                # If we want to mutate this object in place, we need to set this flag,
                # it will trigger an extra check in TypedDict's checker.
                arg_type.to_be_mutated = True
            # Infer the argument's type with `arg_type` as context. The predicate
            # is true for every error whose code is not TYPEDDICT_READONLY_MUTATED.
            # NOTE(review): presumably those matching errors are filtered (and
            # saved) rather than reported, so only read-only mutation errors
            # surface here. Confirm against ErrorWatcher's filter_errors semantics.
            with ctx.api.msg.filter_errors(
                filter_errors=lambda name, info: info.code != codes.TYPEDDICT_READONLY_MUTATED,
                save_filtered_errors=True,
            ):
                inferred = get_proper_type(
                    ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
                )
            # Reset the flag now that the inference that needed it is done.
            if arg_type.to_be_mutated:
                arg_type.to_be_mutated = False  # Done!
            # TypedDicts the argument may be: the inferred type itself, or each
            # TypedDict member of an inferred union.
            possible_tds = []
            if isinstance(inferred, TypedDictType):
                possible_tds = [inferred]
            elif isinstance(inferred, UnionType):
                possible_tds = [
                    t
                    for t in get_proper_types(inferred.relevant_items())
                    if isinstance(t, TypedDictType)
                ]
            items = []
            for td in possible_tds:
                # Keys this candidate requires become required again, restricted
                # to the keys the parameter type declares.
                item = arg_type.copy_modified(
                    required_keys=(arg_type.required_keys | td.required_keys)
                    & arg_type.items.keys()
                )
                if not ctx.api.options.extra_checks:
                    # NOTE(review): presumably narrows the item set to the keys
                    # the candidate actually has; confirm copy_modified(item_names=...).
                    item = item.copy_modified(item_names=list(td.items))
                items.append(item)
            if items:
                arg_type = make_simplified_union(items)
        # Even with no argument, the relaxed anonymous type replaces the parameter.
        return signature.copy_modified(arg_types=[arg_type])
    return signature


def int_pow_callback(ctx: MethodContext) -> Type:
    """Narrow ``int ** <literal>`` to ``int`` or ``float``.

    ``int.__pow__`` has an optional modulo parameter, so two argument slots
    are expected. The refinement applies only when:
    - the exponent slot holds exactly one argument,
    - the modulo slot is empty, and
    - the exponent is an int literal or a negated int literal.

    A non-negative exponent gives ``builtins.int``; a negative one gives
    ``builtins.float``. Anything else keeps the default return type.
    """
    if len(ctx.arg_types) != 2 or len(ctx.arg_types[0]) != 1 or len(ctx.arg_types[1]) != 0:
        return ctx.default_return_type

    operand = ctx.args[0][0]
    if isinstance(operand, IntExpr):
        power = operand.value
    elif isinstance(operand, UnaryExpr) and operand.op == "-" and isinstance(operand.expr, IntExpr):
        power = -operand.expr.value
    else:
        # Exponent is neither an int literal nor a negated one -- give up.
        return ctx.default_return_type

    result_name = "builtins.int" if power >= 0 else "builtins.float"
    return ctx.api.named_generic_type(result_name, [])


def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
    """Infer a more precise return type for int.__neg__ and int.__pos__.

    This is mainly used to infer the return type as LiteralType
    if the original underlying object is a LiteralType object.

    Args:
        ctx: Method context; ``ctx.type`` is the operand's type.
        multiplier: -1 for ``__neg__`` (the default), +1 for ``__pos__``.

    Returns:
        For an int literal operand, the literal ``multiplier * value``:
        - as a bare ``LiteralType`` when the operand is a LiteralType or the
          surrounding type context is literal-like;
        - otherwise as the operand's Instance with an updated ``last_known_value``.
        Bool literal operands, and everything else, get ``ctx.default_return_type``.
    """
    # ``bool`` subclasses ``int`` and inherits ``__neg__``/``__pos__``, so
    # presumably ``-True``/``+False`` reach this hook too. Their result is an
    # ``int`` (``-True == -1``, ``+True == 1``). Keeping the bool Instance or
    # the bool literal fallback would wrongly type it as ``bool``, so bool
    # literals are left to the declared return type.
    if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
        value = ctx.type.last_known_value.value
        fallback = ctx.type.last_known_value.fallback
        if isinstance(value, int) and not isinstance(value, bool):
            if is_literal_type_like(ctx.api.type_context[-1]):
                return LiteralType(value=multiplier * value, fallback=fallback)
            else:
                return ctx.type.copy_modified(
                    last_known_value=LiteralType(
                        value=multiplier * value,
                        fallback=fallback,
                        line=ctx.type.line,
                        column=ctx.type.column,
                    )
                )
    elif isinstance(ctx.type, LiteralType):
        value = ctx.type.value
        fallback = ctx.type.fallback
        if isinstance(value, int) and not isinstance(value, bool):
            return LiteralType(value=multiplier * value, fallback=fallback)
    return ctx.default_return_type


def int_pos_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pos__.

    Unary plus leaves the value unchanged, so this reuses int_neg_callback
    with a multiplier of one instead of minus one.
    """
    return int_neg_callback(ctx, multiplier=1)


def tuple_mul_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

    This is used to return a specific sized tuple if multiplied by Literal int.
    For example, ``tuple[int, str] * 2`` becomes ``tuple[int, str, int, str]``.
    As at runtime, a zero or negative count gives an empty tuple.

    The default return type is kept when:
    - the receiver is not a fixed-shape tuple type;
    - the argument slot is empty (malformed call);
    - the count is not a known int literal; or
    - the tuple contains an unpacked item (``*Ts`` / ``*tuple[X, ...]``).
      Repeating it would put several unpacks into one tuple type, which is
      not a valid tuple type.
    """
    from mypy.types import UnpackType  # only needed here

    if not isinstance(ctx.type, TupleType):
        return ctx.default_return_type
    # Don't assume the argument slot is populated; a malformed call such as
    # `t.__mul__()` would otherwise raise IndexError below.
    if not ctx.arg_types or not ctx.arg_types[0]:
        return ctx.default_return_type
    if any(isinstance(item, UnpackType) for item in ctx.type.items):
        return ctx.default_return_type

    arg_type = get_proper_type(ctx.arg_types[0][0])
    if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
        value = arg_type.last_known_value.value
    elif isinstance(arg_type, LiteralType):
        value = arg_type.value
    else:
        return ctx.default_return_type

    if not isinstance(value, int):
        return ctx.default_return_type
    return ctx.type.copy_modified(items=ctx.type.items * value)
