"""Intermediate representation of functions."""

from __future__ import annotations

import inspect
import math
from collections.abc import Sequence
from typing import Final

from mypy.nodes import ARG_POS, ArgKind, Block, FuncDef
from mypyc.common import BITMAP_BITS, JsonDict, bitmap_name, get_id_from_name, short_id_from_name
from mypyc.ir.ops import (
    Assign,
    AssignMulti,
    BasicBlock,
    Box,
    ControlOp,
    DeserMaps,
    Float,
    Integer,
    LoadAddress,
    LoadLiteral,
    Register,
    TupleSet,
    Value,
)
from mypyc.ir.rtypes import (
    RType,
    bitmap_rprimitive,
    deserialize_type,
    is_bool_rprimitive,
    is_none_rprimitive,
)
from mypyc.namegen import NameGenerator


class RuntimeArg:
    """An argument of a function, as represented in the IR.

    ``kind`` holds one of the ARG_* constants defined in mypy.nodes.
    """

    def __init__(
        self, name: str, typ: RType, kind: ArgKind = ARG_POS, pos_only: bool = False
    ) -> None:
        self.name = name
        self.type = typ
        self.kind = kind
        self.pos_only = pos_only

    @property
    def optional(self) -> bool:
        """Whether callers may omit this argument (delegates to ArgKind.is_optional)."""
        return self.kind.is_optional()

    def __repr__(self) -> str:
        return (
            f"RuntimeArg(name={self.name}, type={self.type}, "
            f"optional={self.optional!r}, pos_only={self.pos_only!r})"
        )

    def serialize(self) -> JsonDict:
        # The kind is stored as its integer enum value; deserialize() maps it back.
        kind_value = int(self.kind.value)
        return {
            "name": self.name,
            "type": self.type.serialize(),
            "kind": kind_value,
            "pos_only": self.pos_only,
        }

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RuntimeArg:
        name = data["name"]
        arg_type = deserialize_type(data["type"], ctx)
        return RuntimeArg(name, arg_type, ArgKind(data["kind"]), data["pos_only"])


class FuncSignature:
    """Signature of a function in IR.

    ``args`` holds the declared arguments followed by any synthetic bitmap
    arguments; use real_args() to get only the declared ones.
    """

    # TODO: Track if method?

    def __init__(self, args: Sequence[RuntimeArg], ret_type: RType) -> None:
        self.args = tuple(args)
        self.ret_type = ret_type
        # Bitmap arguments are used to mark default values for arguments that
        # have types with overlapping error values.
        self.num_bitmap_args = num_bitmap_args(self.args)
        if self.num_bitmap_args:
            extra = [
                RuntimeArg(bitmap_name(i), bitmap_rprimitive, pos_only=True)
                for i in range(self.num_bitmap_args)
            ]
            self.args = self.args + tuple(reversed(extra))

    def real_args(self) -> tuple[RuntimeArg, ...]:
        """Return arguments without any synthetic bitmap arguments."""
        # Guard needed: args[:-0] would be an empty tuple, not the whole tuple.
        if self.num_bitmap_args:
            return self.args[: -self.num_bitmap_args]
        return self.args

    def bound_sig(self) -> FuncSignature:
        """Return this signature with its first declared argument removed.

        Bitmap arguments are not copied; the new FuncSignature recomputes them
        from the remaining arguments.
        """
        return FuncSignature(self.real_args()[1:], self.ret_type)

    def __repr__(self) -> str:
        return f"FuncSignature(args={self.args!r}, ret={self.ret_type!r})"

    def serialize(self) -> JsonDict:
        # Bitmap arguments are omitted; they are recreated on deserialization
        # by the FuncSignature constructor.
        return {
            "args": [arg.serialize() for arg in self.real_args()],
            "ret_type": self.ret_type.serialize(),
        }

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncSignature:
        return FuncSignature(
            [RuntimeArg.deserialize(arg, ctx) for arg in data["args"]],
            deserialize_type(data["ret_type"], ctx),
        )


def num_bitmap_args(args: tuple[RuntimeArg, ...]) -> int:
    """Return the number of bitmap arguments needed for ``args``.

    One bit is reserved for every optional argument whose type has
    overlapping error values; BITMAP_BITS bits fit in one bitmap, and the
    count is rounded up.
    """
    needed_bits = sum(1 for arg in args if arg.type.error_overlap and arg.kind.is_optional())
    # Ceiling division for a non-negative numerator.
    return -(-needed_bits // BITMAP_BITS)


# Values for FuncDecl.kind (FUNC_NORMAL is the default). FuncDecl gives a static
# method a bound signature equal to its full signature; other methods get one
# with the first argument removed.
FUNC_NORMAL: Final = 0
FUNC_STATICMETHOD: Final = 1
FUNC_CLASSMETHOD: Final = 2


class FuncDecl:
    """Declaration of a function in IR: its name, owner and signature, but no body.

    The declared function may be a module-level function, a method, a static
    method, a class method, or a property getter/setter.
    """

    def __init__(
        self,
        name: str,
        class_name: str | None,
        module_name: str,
        sig: FuncSignature,
        kind: int = FUNC_NORMAL,
        is_prop_setter: bool = False,
        is_prop_getter: bool = False,
        implicit: bool = False,
        internal: bool = False,
    ) -> None:
        self.name = name
        self.class_name = class_name
        self.module_name = module_name
        self.sig = sig
        self.kind = kind
        self.is_prop_setter = is_prop_setter
        self.is_prop_getter = is_prop_getter
        # Signature seen through the bound method: none for module-level
        # functions, unchanged for static methods, and with the first argument
        # removed for every other kind of method.
        if class_name is None:
            self.bound_sig: FuncSignature | None = None
        elif kind == FUNC_STATICMETHOD:
            self.bound_sig = sig
        else:
            self.bound_sig = sig.bound_sig()

        # True when the function is absent from the mypy AST and irbuild must
        # synthesize it (currently only supported for property getters/setters).
        self.implicit = implicit

        # True when only direct C-level calls are supported (no wrapper function).
        self.internal = internal

        # Unknown until the corresponding FuncIR is created and assigns it.
        self._line: int | None = None

    @property
    def line(self) -> int:
        """Line number of the function; asserts that it has been assigned."""
        assert self._line is not None
        return self._line

    @line.setter
    def line(self, line: int) -> None:
        self._line = line

    @property
    def id(self) -> str:
        """Identifier built from the name, full name and line number."""
        # Reading self.line already asserts that the line number is known.
        return get_id_from_name(self.name, self.fullname, self.line)

    @staticmethod
    def compute_shortname(class_name: str | None, name: str) -> str:
        """Return 'Class.name' when a class name is given, else just 'name'."""
        if not class_name:
            return name
        return f"{class_name}.{name}"

    @property
    def shortname(self) -> str:
        return FuncDecl.compute_shortname(self.class_name, self.name)

    @property
    def fullname(self) -> str:
        return f"{self.module_name}.{self.shortname}"

    def cname(self, names: NameGenerator) -> str:
        # Reads the raw _line (possibly still None) instead of the asserting property.
        partial = short_id_from_name(self.name, self.shortname, self._line)
        return names.private_name(self.module_name, partial)

    def serialize(self) -> JsonDict:
        # The line number is not part of this dict; FuncIR.serialize stores it.
        return {
            "name": self.name,
            "class_name": self.class_name,
            "module_name": self.module_name,
            "sig": self.sig.serialize(),
            "kind": self.kind,
            "is_prop_setter": self.is_prop_setter,
            "is_prop_getter": self.is_prop_getter,
            "implicit": self.implicit,
            "internal": self.internal,
        }

    # TODO: move this to FuncIR?
    @staticmethod
    def get_id_from_json(func_ir: JsonDict) -> str:
        """Get the id from the serialized FuncIR associated with this FuncDecl"""
        decl_data = func_ir["decl"]
        short = FuncDecl.compute_shortname(decl_data["class_name"], decl_data["name"])
        full = decl_data["module_name"] + "." + short
        return get_id_from_name(decl_data["name"], full, func_ir["line"])

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncDecl:
        return FuncDecl(
            name=data["name"],
            class_name=data["class_name"],
            module_name=data["module_name"],
            sig=FuncSignature.deserialize(data["sig"], ctx),
            kind=data["kind"],
            is_prop_setter=data["is_prop_setter"],
            is_prop_getter=data["is_prop_getter"],
            implicit=data["implicit"],
            internal=data["internal"],
        )


class FuncIR:
    """IR of a function together with its declaration and context.

    In contrast to FuncDecl, this also carries the body as basic blocks.
    """

    def __init__(
        self,
        decl: FuncDecl,
        arg_regs: list[Register],
        blocks: list[BasicBlock],
        line: int = -1,
        traceback_name: str | None = None,
    ) -> None:
        # The declaration, which also holds the signature.
        self.decl = decl
        # One register per function argument.
        self.arg_regs = arg_regs
        # The function body.
        self.blocks = blocks
        # The line number lives on the declaration; the line property reads it back.
        self.decl.line = line
        # Name displayed for this function in tracebacks; None leaves it out of
        # tracebacks entirely.
        self.traceback_name = traceback_name

    # Convenience accessors that forward to the declaration / signature.

    @property
    def line(self) -> int:
        return self.decl.line

    @property
    def args(self) -> Sequence[RuntimeArg]:
        return self.decl.sig.args

    @property
    def ret_type(self) -> RType:
        return self.decl.sig.ret_type

    @property
    def class_name(self) -> str | None:
        return self.decl.class_name

    @property
    def sig(self) -> FuncSignature:
        return self.decl.sig

    @property
    def name(self) -> str:
        return self.decl.name

    @property
    def fullname(self) -> str:
        return self.decl.fullname

    @property
    def id(self) -> str:
        return self.decl.id

    @property
    def internal(self) -> bool:
        return self.decl.internal

    def cname(self, names: NameGenerator) -> str:
        return self.decl.cname(names)

    def __repr__(self) -> str:
        label = f"{self.class_name}.{self.name}" if self.class_name else self.name
        return f"<FuncIR {label}>"

    def serialize(self) -> JsonDict:
        # Blocks are deliberately left out of the serialized form.
        return {
            "decl": self.decl.serialize(),
            "line": self.line,
            "traceback_name": self.traceback_name,
        }

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncIR:
        # No argument registers or blocks are restored, since none were serialized.
        return FuncIR(
            decl=FuncDecl.deserialize(data["decl"], ctx),
            arg_regs=[],
            blocks=[],
            line=data["line"],
            traceback_name=data["traceback_name"],
        )


# Placeholder mypy FuncDef with no arguments and an empty body. Presumably used
# where a FuncDef is required but no real definition exists — TODO confirm
# against its callers.
INVALID_FUNC_DEF: Final = FuncDef("<INVALID_FUNC_DEF>", [], Block([]))


def all_values(args: list[Register], blocks: list[BasicBlock]) -> list[Value]:
    """Return, as a list, every value that may be initialized within ``blocks``.

    The argument registers come first. Each register is listed at most once.
    Registers that are only ever read are not included.
    """
    result: list[Value] = list(args)
    seen = set(args)

    for block in blocks:
        for op in block.ops:
            if isinstance(op, ControlOp):
                continue
            if isinstance(op, (Assign, AssignMulti)):
                if op.dest not in seen:
                    result.append(op.dest)
                    seen.add(op.dest)
                continue
            if op.is_void:
                continue
            # A register whose address is taken might be initialized through it.
            if isinstance(op, LoadAddress) and isinstance(op.src, Register) and op.src not in seen:
                result.append(op.src)
                seen.add(op.src)
            result.append(op)

    return result


def all_values_full(args: list[Register], blocks: list[BasicBlock]) -> list[Value]:
    """Return, as a list, every value that is initialized or read in ``blocks``.

    Unlike all_values, registers that are only read are included as well.
    """
    result: list[Value] = list(args)
    seen = set(args)

    for block in blocks:
        for op in block.ops:
            # Pick up registers that are read, even if never written. Non-register
            # sources are ignored since ops outside basic blocks are not allowed.
            for src in op.sources():
                if isinstance(src, Register) and src not in seen:
                    result.append(src)
                    seen.add(src)
            if isinstance(op, ControlOp):
                continue
            if isinstance(op, (Assign, AssignMulti)):
                if op.dest not in seen:
                    result.append(op.dest)
                    seen.add(op.dest)
                continue
            if not op.is_void:
                result.append(op)

    return result


# Maps mypy argument kinds onto inspect.Parameter kinds for text signatures.
# Optional kinds share the kind of their required counterparts (defaults are
# attached separately). Positional-only parameters are not expressed here:
# get_text_signature derives those from RuntimeArg.pos_only.
_ARG_KIND_TO_INSPECT: Final = {
    ArgKind.ARG_POS: inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ArgKind.ARG_OPT: inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ArgKind.ARG_STAR: inspect.Parameter.VAR_POSITIONAL,
    ArgKind.ARG_NAMED: inspect.Parameter.KEYWORD_ONLY,
    ArgKind.ARG_STAR2: inspect.Parameter.VAR_KEYWORD,
    ArgKind.ARG_NAMED_OPT: inspect.Parameter.KEYWORD_ONLY,
}

# Sentinel indicating a value that cannot be represented in a text signature.
# Returned by the default-value helpers below; get_text_signature gives up
# (returns None) when it sees it.
_NOT_REPRESENTABLE = object()


def get_text_signature(fn: FuncIR, *, bound: bool = False) -> str | None:
    """Return a text signature in CPython's internal doc format, or None
    if the function's signature cannot be represented.

    If ``bound`` is true, the bound signature is used (when the declaration
    has one) and no parameter is marked as self. Otherwise, for methods that
    are not static methods, the first parameter is prefixed with '$'.

    None is returned when some default value cannot be written as a literal.
    """
    parameters: list[inspect.Parameter] = []
    mark_self = (fn.class_name is not None) and (fn.decl.kind != FUNC_STATICMETHOD) and not bound
    sig = fn.decl.bound_sig if bound and fn.decl.bound_sig is not None else fn.decl.sig
    # Drop synthetic arguments (default bitmaps and other mypyc internals) up front.
    # Bitmap arguments are appended to the end of sig.args with pos_only=True, so if
    # they took part in the positional-only scan below, every real parameter would
    # be rendered as positional-only.
    args = [arg for arg in sig.args if not arg.name.startswith(("__bitmap", "__mypyc"))]
    # Pre-scan for end of positional-only parameters.
    # This is needed to handle signatures like 'def foo(self, __x)', where mypy
    # currently sees 'self' as being positional-or-keyword and '__x' as positional-only.
    pos_only_idx = -1
    for idx, arg in enumerate(args):
        if arg.pos_only and arg.kind in (ArgKind.ARG_POS, ArgKind.ARG_OPT):
            pos_only_idx = idx
    for idx, arg in enumerate(args):
        kind = (
            inspect.Parameter.POSITIONAL_ONLY
            if idx <= pos_only_idx
            else _ARG_KIND_TO_INSPECT[arg.kind]
        )
        default: object = inspect.Parameter.empty
        if arg.optional:
            default = _find_default_argument(arg.name, fn.blocks)
            if default is _NOT_REPRESENTABLE:
                # This default argument cannot be represented in a __text_signature__
                return None

        curr_param = inspect.Parameter(arg.name, kind, default=default)
        parameters.append(curr_param)
        if mark_self:
            # Parameter.__init__/Parameter.replace do not accept $
            curr_param._name = f"${arg.name}"  # type: ignore[attr-defined]
            mark_self = False
    return f"{fn.name}{inspect.Signature(parameters)}"


def _find_default_argument(name: str, blocks: list[BasicBlock]) -> object:
    """Return the Python value of argument ``name``'s default, or _NOT_REPRESENTABLE.

    The default is taken from the first Assign to a register called ``name``,
    which is assumed to be the one inserted by gen_arg_defaults.
    """
    assignments = (
        op
        for block in blocks
        for op in block.ops
        if isinstance(op, Assign) and op.dest.name == name
    )
    first = next(assignments, None)
    if first is None:
        return _NOT_REPRESENTABLE
    return _extract_python_literal(first.src)


def _checked_literal(lit: object) -> object:
    """Return ``lit``, or _NOT_REPRESENTABLE if its repr() is not a valid literal.

    Non-finite floats repr as 'inf'/'nan', which are names rather than literals,
    so they (and complex numbers or tuples containing them) are rejected.
    """
    if isinstance(lit, float):
        return lit if math.isfinite(lit) else _NOT_REPRESENTABLE
    if isinstance(lit, complex):
        finite = math.isfinite(lit.real) and math.isfinite(lit.imag)
        return lit if finite else _NOT_REPRESENTABLE
    if isinstance(lit, tuple):
        if any(_checked_literal(item) is _NOT_REPRESENTABLE for item in lit):
            return _NOT_REPRESENTABLE
    return lit


def _extract_python_literal(value: Value) -> object:
    """Convert a constant IR value into the equivalent Python object.

    Returns _NOT_REPRESENTABLE when the value is not a supported constant, or
    when its repr() would not be a valid literal in a text signature (e.g. a
    float literal such as ``1e400`` that overflowed to infinity).
    """
    if isinstance(value, Integer):
        if is_none_rprimitive(value.type):
            return None
        val = value.numeric_value()
        if is_bool_rprimitive(value.type):
            return bool(val)
        return val
    elif isinstance(value, Float):
        return _checked_literal(value.value)
    elif isinstance(value, LoadLiteral):
        return _checked_literal(value.value)
    elif isinstance(value, Box):
        return _extract_python_literal(value.src)
    elif isinstance(value, TupleSet):
        items = tuple(_extract_python_literal(item) for item in value.items)
        if any(itm is _NOT_REPRESENTABLE for itm in items):
            return _NOT_REPRESENTABLE
        return items
    return _NOT_REPRESENTABLE
