"""The code generator for modeling languages.

This is the code generator for the models used by Gaphor.

In order to work with the code generator, a model should follow some conventions:

* `Profile` packages are only for profiles (excluded from generation)
* A stereotype `SimpleAttribute` can be defined, which converts an association
  to a `str` attribute
* A stereotype attribute `subsets` can be defined in case an association is derived

The coder first writes the class declarations, including attributes and enumerations.
After that, associations are filled in, including derived unions and redefines.

Notes:
* Enumerations are classes ending with "Kind" or "Sort".

The code generator works by reading a model and the models it depends on.
It defines classes, attributes, enumerations and associations. Class names
are considered unique.
"""

from __future__ import annotations

import argparse
import contextlib
import keyword
import logging
import sys
import textwrap
from collections.abc import Iterable
from pathlib import Path

import gaphor.storage as storage
from gaphor import UML
from gaphor.codegen.override import Overrides
from gaphor.core.modeling import Base, ElementFactory
from gaphor.core.modeling.modelinglanguage import (
    CoreModelingLanguage,
    MockModelingLanguage,
    ModelingLanguage,
)
from gaphor.diagram.general.modelinglanguage import GeneralModelingLanguage
from gaphor.entrypoint import initialize
from gaphor.UML.modelinglanguage import UMLModelingLanguage

# Module-level logger; used below to warn about model elements that can not
# be turned into code.
log = logging.getLogger(__name__)

# Preamble of every generated file. `coder()` yields it before anything else.
# It carries the linter/formatter directives and the imports the generated
# declarations rely on. The linter name is inserted through `format()` so the
# directive is not spelled out literally in this file.
header = textwrap.dedent(
    """\
    # This file is generated by coder.py. DO NOT EDIT!
    # {}: noqa: F401, E402, F811
    # fmt: off

    from __future__ import annotations

    import enum

    from gaphor.core.modeling.properties import (
        association,
        attribute as _attribute,
        derived,
        derivedunion,
        enumeration as _enumeration,
        redefine,
        relation_many,
        relation_one,
    )

    """.format("ruff")  # work around tooling triggers
)


def main(
    modelfile: str,
    supermodelfiles: list[tuple[str, str]] | None = None,
    overridesfile: str | None = None,
    outfile: str | None = None,
):
    """Generate the Python data model for ``modelfile``.

    Args:
        modelfile: Path of the Gaphor model to generate code for.
        supermodelfiles: Optional ``(language, model file)`` pairs for the
            models this model depends on.
        overridesfile: Optional path of the overrides file.
        outfile: Path the generated code is written to. Standard output is
            used when it is omitted.
    """
    logging.basicConfig()

    dependencies = supermodelfiles or []

    # Core, general and UML are always passed to the mock language below, so
    # only the other languages are loaded through their entry point.
    extra_langs = [
        load_modeling_language(lang)
        for lang, _ in dependencies
        if lang not in ("Core", "general", "UML")
    ]
    modeling_language = MockModelingLanguage(
        CoreModelingLanguage(),
        GeneralModelingLanguage(),
        UMLModelingLanguage(),
        *extra_langs,
    )

    model = load_model(modelfile, modeling_language)

    super_models = {}
    for lang, path in dependencies:
        super_models[lang] = (
            load_modeling_language(lang),
            load_model(path, modeling_language),
        )

    overrides = None
    if overridesfile:
        overrides = Overrides(overridesfile)

    if outfile:
        destination = open(outfile, "w", encoding="utf-8")
    else:
        # nullcontext hands out sys.stdout without closing it afterwards.
        destination = contextlib.nullcontext(sys.stdout)

    with destination as out:
        for line in coder(model, super_models, overrides):
            print(line, file=out)


def load_model(modelfile: str, modeling_language: ModelingLanguage) -> ElementFactory:
    """Load ``modelfile`` into a fresh element factory.

    After loading, the attribute type values are normalized by
    `resolve_attribute_type_values()`.

    Args:
        modelfile: Path of the model file; it is read as UTF-8.
        modeling_language: Modeling language handed to the storage loader.

    Returns:
        The element factory that holds the loaded model.
    """
    factory = ElementFactory()
    with open(modelfile, encoding="utf-8") as stream:
        storage.load(stream, factory, modeling_language)

    resolve_attribute_type_values(factory)
    return factory


def load_modeling_language(lang) -> ModelingLanguage:
    """Return the modeling language registered as ``lang``.

    The language is looked up in the ``gaphor.modelinglanguages`` entry
    point group.
    """
    languages = initialize("gaphor.modelinglanguages", [lang])
    return languages[lang]


def coder(
    model: ElementFactory,
    super_models: dict[str, tuple[ModelingLanguage, ElementFactory]],
    overrides: Overrides | None,
) -> Iterable[str]:
    """Yield the generated module for ``model``, one chunk of text at a time.

    The output is produced in this order: the file header (plus the override
    header, if any), imports of classes that live in a super model,
    enumerations, class declarations, overridden operations and finally the
    associations together with their subset registrations.

    Note that classes imported from a super model are renamed in the model
    itself: their name gets a leading underscore.

    Args:
        model: The model to generate code for.
        super_models: Per namespace, the modeling language and loaded model
            this model depends on.
        overrides: Optional hand-written replacements for generated code.
    """
    yield header
    if overrides and overrides.header:
        yield overrides.header

    def should_generate(candidate) -> bool:
        return not (
            is_enumeration(candidate)
            or is_simple_type(candidate)
            or is_in_profile(candidate)
            or is_tilde_type(candidate)
        )

    classes = list(
        order_classes(c for c in model.select(UML.Class) if should_generate(c))
    )

    already_imported = set()
    for c in classes:
        if any(bases(c)):
            continue
        element_type, cls = in_super_model(c, super_models)
        if not (element_type and cls):
            continue
        # always alias imported names
        c.name = f"_{c.name}"
        import_line = f"from {element_type.__module__} import {element_type.__name__} as {c.name}"
        yield import_line
        already_imported.add(import_line)

    yield ""
    yield ""

    for enum_type in sorted(model.select(UML.Enumeration), key=lambda e: e.name):
        yield from enumeration(enum_type)

    for c in classes:
        if c.name.startswith("_"):
            # imported from super model
            continue

        if overrides and overrides.has_override(c.name):
            yield overrides.get_override(c.name)
            continue

        yield class_declaration(c)
        properties = list(variables(c, overrides))
        if not properties:
            yield "    pass"
        else:
            for declaration in properties:
                yield f"    {declaration}"
        yield ""
        yield ""

    for c in classes:
        yield from operations(c, overrides)

    yield ""

    for c in classes:
        yield from associations(c, overrides)
        for line in subsets(c, super_models):
            if not line.startswith("from "):
                yield line
            elif line not in already_imported:
                # Emit every import statement only once.
                yield line
                already_imported.add(line)


def enumeration(enum: UML.Enumeration):
    """Yield the lines of a ``StrEnum`` class for ``enum``.

    Each owned literal becomes a member whose value is the literal name.
    Member names that are Python keywords get a trailing underscore. An
    enumeration without literals gets a ``pass`` body, so the generated
    class is still valid Python. The class is followed by two empty lines.
    """
    yield f"class {enum.name}(enum.StrEnum):"

    has_literals = False
    for literal in enum.ownedLiteral:
        has_literals = True
        name = literal.name
        if keyword.iskeyword(name):
            name = f"{name}_"
        yield f'    {name} = "{literal.name}"'

    if not has_literals:
        # A class statement needs at least one statement in its body.
        yield "    pass"

    yield ""
    yield ""


def class_declaration(class_: UML.Class):
    """Return the ``class Name(Base, ...):`` line for ``class_``.

    The base class names are listed in alphabetical order.
    """
    base_names = sorted(base.name for base in bases(class_))
    base_classes = ", ".join(base_names)
    return f"class {class_.name}({base_classes}):"


def variables(class_: UML.Class, overrides: Overrides | None = None):
    """Yield the attribute declarations for the body of ``class_``.

    Attributes are handled in name order, followed by the operations in name
    order. Extension ends are skipped. Per attribute the first matching rule
    applies:

    * an override exists: the overridden type is declared;
    * derived without a type: a warning is logged, nothing is declared;
    * a type value is set: a plain ``_attribute``;
    * the type is an enumeration: an ``_enumeration``, defaulting to the
      string default value or else the first literal;
    * any other type: a ``relation_one``/``relation_many`` annotation.

    Operations are only declared when an override provides their type;
    otherwise a warning is logged.

    Raises:
        ValueError: If an attribute has neither a type nor a type value.
    """
    prop: UML.Property
    for prop in sorted(class_.ownedAttribute or (), key=lambda p: p.name or ""):
        if is_extension_end(prop):
            continue

        qualified = f"{class_.name}.{prop.name}"

        if overrides and overrides.has_override(qualified):
            yield f"{prop.name}: {overrides.get_type(qualified)}"
            continue

        if prop.isDerived and not prop.type:
            log.warning(f"Derived attribute {qualified} has no implementation.")
            continue

        if prop.typeValue:
            yield f'{prop.name}: _attribute[{prop.typeValue}] = _attribute("{prop.name}", {prop.typeValue}{default_value(prop)})'
            continue

        if is_enumeration(prop.type):
            assert isinstance(prop.type, UML.Enumeration)
            if (
                isinstance(prop.defaultValue, UML.LiteralString)
                and prop.defaultValue.value
            ):
                default = prop.defaultValue.value
            else:
                default = prop.type.ownedLiteral[0].name
            # Keyword literals are generated with a trailing underscore.
            if keyword.iskeyword(default):
                default = f"{default}_"
            yield f'{prop.name} = _enumeration("{prop.name}", {prop.type.name}, {prop.type.name}.{default})'
            continue

        if prop.type:
            if UML.recipes.get_multiplicity_upper_value_as_string(prop) == "1":
                mult = "one"
            else:
                mult = "many"
            # Redeclaring an inherited attribute name needs a type-checker
            # suppression in the generated code.
            comment = ""
            if is_reassignment(prop):
                comment = "  # type: ignore[assignment]"
            yield f"{prop.name}: relation_{mult}[{prop.type.name}]{comment}"
            continue

        assert isinstance(prop.owner, Base)
        raise ValueError(
            f"{prop.name}: {prop.type} can not be written; owner={prop.owner.name}"  # type: ignore[attr-defined]
        )

    for operation in sorted(class_.ownedOperation or (), key=lambda o: o.name or ""):
        qualified = f"{class_.name}.{operation.name}"
        if overrides and overrides.has_override(qualified):
            yield f"{operation.name}: {overrides.get_type(qualified)}"
        else:
            log.warning(f"Operation {qualified} has no implementation")


def associations(
    c: UML.Class,
    overrides: Overrides | None = None,
):
    """Yield the association definitions for the attributes of ``c``.

    Overridden attributes yield their override. Attributes without a type,
    with a simple or enumeration type, and extension ends are skipped. The
    remaining attributes become a ``redefine``, a ``derivedunion`` or an
    ``association``. Redefinitions are collected and yielded last.

    Raises:
        ValueError: If a plain association end has no name.
    """
    deferred_redefinitions = []
    for prop in c.ownedAttribute:
        full_name = f"{c.name}.{prop.name}"

        if overrides and overrides.has_override(full_name):
            yield overrides.get_override(full_name)
            continue

        if (
            not prop.type
            or is_simple_type(prop.type)
            or is_enumeration(prop.type)
            or is_extension_end(prop)
        ):
            continue

        if redefined := redefines(prop):
            deferred_redefinitions.append(
                f'{full_name} = redefine({c.name}, "{prop.name}", {prop.type.name}, {redefined}{opposite(prop)})'
            )
            continue

        if prop.isDerived:
            yield f'{full_name} = derivedunion("{prop.name}", {prop.type.name}{lower(prop)}{upper(prop)})'
            continue

        if not prop.name:
            raise ValueError(f"Unnamed attribute: {full_name} ({prop.association})")

        yield f'{full_name} = association("{prop.name}", {prop.type.name}{lower(prop)}{upper(prop)}{composite(prop)}{opposite(prop)})'

    yield from deferred_redefinitions


def subsets(
    c: UML.Class,
    super_models: dict[str, tuple[ModelingLanguage, ElementFactory]],
):
    """Yield the statements that register the subsets declared on ``c``.

    For every attribute of ``c`` that has a ``subsets`` stereotype slot, the
    comma separated superset names in that slot are resolved with
    `superset_attribute()`. For each superset that is a derived union a
    ``<Owner>.<superset>.add(<Class>.<attribute>)`` line is yielded. When the
    superset lives in a super model, an aliased import line (starting with
    ``from``) is yielded right before it.

    A warning is logged, and nothing is yielded, when a superset can not be
    found or when it is not a derived union.

    Args:
        c: The class whose attributes are inspected.
        super_models: Per namespace, the modeling language and loaded model
            this model depends on.
    """
    for a in c.ownedAttribute:
        if (
            not a.type
            or is_simple_type(a.type)
            or is_enumeration(a.type)
            or is_extension_end(a)
        ):
            continue
        for slot in a.appliedStereotype[:].slot:
            if slot.definingFeature.name != "subsets":
                continue

            full_name = f"{c.name}.{a.name}"
            raw_slot_value = UML.recipes.get_slot_value(slot)
            slot_value = raw_slot_value if isinstance(raw_slot_value, str) else ""
            for value in slot_value.split(","):
                superset_name = value.strip()
                element_type, d = superset_attribute(c, superset_name, super_models)
                # Only derived unions are generated with an `add()` method;
                # anything else is reported instead of generating broken code.
                if d and d.isDerived:
                    assert isinstance(d.owner, UML.NamedElement)
                    if element_type:
                        owner_name = f"_{d.owner.name}"
                        # Line will be filtered out if it's already imported.
                        yield f"from {element_type.__module__} import {d.owner.name} as {owner_name}"
                    else:
                        owner_name = d.owner.name
                    yield f"{owner_name}.{d.name}.add({full_name})  # type: ignore[attr-defined]"
                elif not d:
                    log.warning(
                        f"{full_name} wants to subset {superset_name}, but it is not defined"
                    )
                else:
                    log.warning(
                        f"{full_name} wants to subset {superset_name}, but it is not a derived union"
                    )


def operations(c: UML.Class, overrides: Overrides | None = None):
    """Yield the override of every overridden operation of ``c``.

    Operations are visited in name order. Operations without an override
    yield nothing.
    """
    for operation in sorted(c.ownedOperation or (), key=lambda o: o.name or ""):
        qualified = f"{c.name}.{operation.name}"
        if overrides and overrides.has_override(qualified):
            yield overrides.get_override(qualified)


def default_value(a) -> str:
    """Return the ``, default=...`` argument for attribute ``a``.

    An empty string is returned when the attribute has no default value.
    Literal elements are converted with
    ``UML.recipes.get_literal_value_as_string``. For ``bool`` attributes the
    values ``true`` and ``false`` are written as ``True`` and ``False``.

    Raises:
        ValueError: If a default value is set and the type value is not
            ``int``, ``str`` or ``bool``.
    """
    value = a.defaultValue
    if not value:
        return ""

    kind = a.typeValue
    if kind == "int" or kind == "str":
        if isinstance(
            value,
            (
                UML.LiteralString,
                UML.LiteralInteger,
                UML.LiteralUnlimitedNatural,
                UML.LiteralBoolean,
            ),
        ):
            rendered = UML.recipes.get_literal_value_as_string(value)
        elif kind == "int":
            rendered = value.title()
        else:
            rendered = f'"{value}"'
    elif kind == "bool":
        if isinstance(value, (UML.LiteralBoolean, UML.LiteralString)):
            rendered = UML.recipes.get_literal_value_as_string(value)
        else:
            rendered = value
        if rendered == "true":
            rendered = "True"
        elif rendered == "false":
            rendered = "False"
    else:
        raise ValueError(
            f"Unknown default value type: {a.owner.name}.{a.name}: {a.typeValue} = {a.defaultValue}"
        )

    return f", default={rendered}"


def lower(a):
    """Return the ``, lower=...`` argument for attribute ``a``.

    An empty string is returned when no lower value is set or when it is
    zero, both for integer literals (``0``) and for plain values (``"0"``).
    """
    bound = a.lowerValue
    if isinstance(bound, UML.LiteralInteger):
        number = bound.value
        text = str(number) if number and number != 0 else ""
    elif bound is None or bound == "0":
        text = ""
    else:
        text = bound

    if text == "":
        return ""
    return f", lower={text}"


def upper(a):
    """Return the ``, upper=...`` argument for attribute ``a``.

    An empty string is returned when no upper value is set or when it is
    ``*``. Values of unlimited-natural literals are written as integers.
    """
    bound = a.upperValue
    if isinstance(bound, UML.LiteralUnlimitedNatural):
        raw = bound.value
        text = str(int(raw)) if raw and raw != "*" else ""
    elif bound is None or bound == "*":
        text = ""
    else:
        text = bound

    if text == "":
        return ""
    return f", upper={text}"


def composite(a):
    """Return ``, composite=True`` when ``a`` is a composite aggregation.

    For any other aggregation an empty string is returned.
    """
    if a.aggregation == "composite":
        return ", composite=True"
    return ""


def opposite(a):
    """Return the ``, opposite="..."`` argument for attribute ``a``.

    The argument is only produced when the opposite end exists, has a name
    and has its ``class_`` set. Otherwise an empty string is returned.
    """
    other_end = a.opposite
    if other_end and other_end.name and other_end.class_:
        return f', opposite="{other_end.name}"'
    return ""


def order_classes(classes: Iterable[UML.Class]) -> Iterable[UML.Class]:
    """Yield ``classes`` so that every class comes after its base classes.

    Base classes are yielded as well, even when they are not part of
    ``classes``. Every class is yielded once.
    """
    emitted = set()

    def visit(node):
        if node in emitted:
            return
        # Depth first: the bases have to be declared before the class.
        for parent in bases(node):
            yield from visit(parent)
        yield node
        emitted.add(node)

    for cls in classes:
        yield from visit(cls)


def bases(c: UML.Class) -> Iterable[UML.Class]:
    """Yield the base classes of ``c``.

    First come the general classes of the generalizations of ``c``. They are
    followed by the class at the owned end of the association of every
    attribute named ``baseClass``.
    """
    for generalization in c.generalization:
        general = generalization.general
        assert isinstance(general, UML.Class)
        yield general

    yield from (
        attribute.association.ownedEnd.class_  # type: ignore[attr-defined]
        for attribute in c.ownedAttribute
        if attribute.association and attribute.name == "baseClass"
    )


def is_enumeration(c: UML.Type) -> bool:
    """Tell whether ``c`` is a ``UML.Enumeration``."""
    result = isinstance(c, UML.Enumeration)
    return result


def is_simple_type(c: UML.Type) -> bool:
    """Tell whether ``c`` is to be generated as a simple attribute.

    That is the case when the ``SimpleAttribute`` stereotype is applied to
    ``c`` itself or to one of its (indirect) general types.
    """
    for stereotype in UML.recipes.get_applied_stereotypes(c):
        if stereotype.name == "SimpleAttribute":
            return True
    return any(is_simple_type(g.general) for g in c.generalization)  # type: ignore[attr-defined]


def is_tilde_type(c: UML.Type) -> bool:
    """Tell whether the name of ``c`` starts with a tilde (``~``).

    ``False`` is returned when ``c`` is missing or has no name. The result is
    always a real ``bool``, as the annotation states.
    """
    return bool(c and c.name and c.name.startswith("~"))


def is_extension_end(a: UML.Property):
    """Tell whether the association of ``a`` is a ``UML.Extension``."""
    association_ = a.association
    return isinstance(association_, UML.Extension)


def is_reassignment(a: UML.Property) -> bool:
    """Tell whether a base class of the owner declares the name of ``a`` too.

    All direct and indirect base classes of the owner of ``a`` are searched
    for an attribute with the same name.
    """

    def declares(cls: UML.Class):
        if any(attr.name == a.name for attr in cls.ownedAttribute):
            return True
        return any(declares(parent) for parent in bases(cls))

    return any(declares(parent) for parent in bases(a.owner))  # type:ignore[arg-type]


def is_in_profile(c: UML.Classifier) -> bool:
    """Tell whether ``c`` is owned, directly or indirectly, by a profile.

    The owning packages of ``c`` are walked upwards until a ``UML.Profile``
    is found. ``False`` is returned when there is none, which includes a
    classifier that is not owned by any package.
    """
    package = c.owningPackage
    while package:
        if isinstance(package, UML.Profile):
            return True
        package = package.owningPackage
    return False


def is_in_toplevel_package(c: UML.Class, package_name: str) -> bool:
    """Tell whether the outermost package that owns ``c`` is ``package_name``.

    The owning packages of ``c`` are walked upwards to the package that has
    no owning package itself; its name is compared to ``package_name``.
    ``False`` is returned when ``c`` is not owned by any package.
    """
    package = c.owningPackage
    if not package:
        return False
    while package.owningPackage:
        package = package.owningPackage
    return bool(package.name == package_name)


def redefines(a: UML.Property) -> str | None:
    """Return the value of the ``redefines`` stereotype slot of ``a``.

    The first slot whose defining feature is named ``redefines`` is used.
    ``None`` is returned when ``a`` has no such slot.
    """
    # TODO: look up element name and add underscore if needed.
    # maybe resolve redefines before we start writing?
    # NOTE(review): the original comment continued with "Redefine is the only
    # one where" and broke off there.
    for slot in a.appliedStereotype[:].slot:
        if slot.definingFeature.name == "redefines":
            return UML.recipes.get_slot_value(slot)
    return None


def superset_attribute(
    c: UML.Class,
    name: str,
    super_models: dict[str, tuple[ModelingLanguage, ElementFactory]],
) -> tuple[type[Base] | None, UML.Property | None]:
    """Lookup an attribute from a super type.

    The attribute is searched for on ``c`` itself, then recursively on its
    base classes and finally on the counterpart of ``c`` in a super model.

    Returns:
        A pair of the element type and the attribute. The element type is
        only set when the attribute was found through a super model. Both
        are ``None`` when there is no attribute called ``name``.
    """
    for own in c.ownedAttribute:
        if own.name == name:
            return None, own

    for base in bases(c):
        element_type, inherited = superset_attribute(base, name, super_models)
        if inherited:
            return element_type, inherited

    element_type, super_class = in_super_model(c, super_models)
    if not super_class or c is super_class:
        return None, None

    assert isinstance(super_class, UML.Class)
    _, found = superset_attribute(super_class, name, super_models)
    return element_type, found


def in_super_model(
    type: UML.Type, super_models: dict[str, tuple[ModelingLanguage, ElementFactory]]
) -> tuple[type[Base], UML.Classifier] | tuple[None, None]:
    """Find the counterpart of ``type`` in the super model of its namespace.

    The namespace is the dotted qualified name of the package that owns
    ``type``.

    Returns:
        ``(None, None)`` when the namespace is not a super model. Otherwise
        the element type from the modeling language together with the
        matching classifier from the super model.

    Raises:
        AssertionError: If the namespace is a super model, but no matching
            classifier or element type can be found.
    """
    ns = ".".join(type.owningPackage.qualifiedName)
    if ns not in super_models:
        return None, None
    modeling_language, factory = super_models[ns]

    # type.name may have been prefixed by an underscore, if it's imported
    name = type.name.removeprefix("_")

    def is_counterpart(e):
        return (
            isinstance(e, UML.Classifier)
            and e.name == name
            and ".".join(e.owningPackage.qualifiedName) not in super_models
        )

    cls: UML.Classifier | None = next(iter(factory.select(is_counterpart)), None)  # type: ignore[assignment]
    if cls is None:
        raise AssertionError(
            f"Type {ns}.{name} found in model, but not in generated model"
        )

    element_type = modeling_language.lookup_element(cls.name, ns=ns)
    assert element_type, (
        f"Type {ns}.{name} found in model, but not in generated model"
    )
    return element_type, cls


def resolve_attribute_type_values(element_factory: ElementFactory) -> None:
    """Some model updates that are hard to do from Gaphor itself.

    For every property in ``element_factory``:

    * primitive type names are normalized to ``str``, ``bool`` or ``int``
      (``UnlimitedNatural`` is left as it is);
    * any other type value that names a class or enumeration in the model is
      replaced by a composite reference to that element;
    * a type that is a simple type is replaced by the type value ``str``.

    Raises:
        ValueError: If a property ends up without a type and with a type
            value that is not one of the known primitive names.
    """
    primitive_names = {
        "String": "str",
        "str": "str",
        "object": "str",
        "Boolean": "bool",
        "bool": "bool",
        "Integer": "int",
        "int": "int",
    }
    accepted_type_values = ("str", "int", "bool", "UnlimitedNatural", None)

    for prop in element_factory.select(UML.Property):
        declared = prop.typeValue
        if declared in primitive_names:
            prop.typeValue = primitive_names[declared]
        elif declared == "UnlimitedNatural":
            pass
        else:
            referenced = next(
                element_factory.select(
                    lambda e: isinstance(e, UML.Class | UML.Enumeration)
                    and e.name == declared  # noqa: B023
                ),
                None,
            )
            if referenced:
                prop.type = referenced  # type: ignore[assignment]
                del prop.typeValue
                prop.aggregation = "composite"

        if prop.type and is_simple_type(prop.type):
            prop.typeValue = "str"
            del prop.type

        if not prop.type and prop.typeValue not in accepted_type_values:
            raise ValueError(f"Property value type {prop.typeValue} can not be found")


# Command line entry point: parse the arguments and run the generator.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("modelfile", type=Path, help="Gaphor model filename")
    parser.add_argument(
        "-o", dest="outfile", type=Path, help="Python data model filename"
    )
    parser.add_argument("-r", dest="overridesfile", type=Path, help="Override filename")
    parser.add_argument(
        "-s",
        dest="supermodelfiles",
        type=str,
        action="append",
        help="Reference to dependent model file (e.g. UML:models/UML.gaphor)",
    )

    args = parser.parse_args()

    # A reference has the form LANG:FILE. Only the first colon separates the
    # two parts, so file names that contain a colon themselves (such as
    # Windows drive letters) stay intact.
    supermodelfiles: list[tuple[str, str]] = []
    for reference in args.supermodelfiles or []:
        lang, separator, path = reference.partition(":")
        if not separator:
            parser.error(
                f"invalid super model reference {reference!r}, expected LANG:FILE"
            )
        supermodelfiles.append((lang, path))

    main(args.modelfile, supermodelfiles, args.overridesfile, args.outfile)
