#! /usr/bin/env python3

# Import the fastest jsons library available
try:
    import orjson as json
except ImportError:
    try:
        import ujson as json
    except ImportError:
        try:
            import simplejson as json
        except ImportError:
            import json

import argparse
import sys
from collections import namedtuple
import pathlib
import lzma
import sqlite3
from functools import cache


# Format version of the per-rank profiling files (.json/.txt) written by preCICE.
# Loaded files reporting a different version trigger a compatibility warning.
RUN_FILE_VERSION: int = 2
# Format version of the merged profiling database produced by the merge command.
# NOTE(review): not referenced elsewhere in this script — presumably for other tooling.
MERGED_FILE_VERSION: int = 1


def mergedDict(dict1, dict2):
    """Return a copy of ``dict1`` overlaid with the entries of ``dict2``.

    Keys present in both take the value from ``dict2``. Neither argument is modified.
    """
    combined = dict1.copy()  # copy first so the caller's dict stays untouched
    combined.update(dict2)
    return combined


def warning(message, filename=None):
    """Print ``message`` to stdout as a yellow ``warning:``.

    If ``filename`` is truthy, the line is prefixed with ``<filename>: ``.
    """
    location = ""
    if filename:
        location = str(filename) + ": "
    print(location + f"\033[33mwarning:\033[0m {message}")


def readJSON(filename: pathlib.Path):
    """Load a ``.json`` profiling file, attempting to repair a truncated file.

    If the content does not parse, ``]}`` is appended, which closes the event list
    and the top-level object, and parsing is retried once.

    Returns:
        The parsed JSON content, or ``{}`` if the file cannot be parsed even after
        the repair attempt. Both failure stages emit a warning.
    """
    assert filename.suffix == ".json"
    # Profiling files are assumed to be UTF-8. Undecodable bytes (e.g. a multi-byte
    # character cut off by truncation) are replaced so the repair below can still run
    # instead of aborting the whole merge with a UnicodeDecodeError.
    content = filename.read_text(encoding="utf-8", errors="replace")
    try:
        return json.loads(content)  # try direct
    except ValueError:
        # The decode errors of json, simplejson, ujson and orjson all derive from
        # ValueError. Catching only that keeps KeyboardInterrupt and the like working.
        warning("File damaged. Attempting to terminate truncated event file.", filename)
    try:
        return json.loads(content + "]}")  # try terminated
    except ValueError:
        warning("Unable to load critically damaged file.", filename)
        return {}  # give up


def expandTXTRecord(s: str):
    parts = s[1:].rstrip().split(":")
    match s[0]:
        case "N":
            eid, name = int(parts[0]), parts[1]
            return {"et": "n", "eid": eid, "en": name}
        case "B":
            eid, ts = map(int, parts)
            return {"et": "b", "eid": eid, "ts": ts}
        case "E":
            eid, ts = map(int, parts)
            return {"et": "e", "eid": eid, "ts": ts}
        case "D":
            eid, ts, dn, dv = map(int, parts)
            return {"et": "d", "eid": eid, "ts": ts, "dn": dn, "dv": dv}
    assert False


def readTXT(filename: pathlib.Path):
    """Read a compact ``.txt`` profiling file.

    The first line is a JSON object with the metadata and must contain a
    ``compression`` flag. If the flag is truthy, the rest of the file is read through
    an LZMA decompressor. Each remaining non-blank line is one event record, expanded
    by ``expandTXTRecord``.

    Returns:
        ``{"meta": <metadata dict>, "events": <list of event dicts>}``
    """
    with filename.open("rb") as file:
        meta = json.loads(file.readline())
        assert "compression" in meta
        stream = lzma.LZMAFile(file) if meta["compression"] else file
        try:
            # Blank lines (e.g. a trailing empty line) carry no record
            events = [
                expandTXTRecord(line.decode()) for line in stream if line.strip()
            ]
        finally:
            # Release the decompressor. Closing an LZMAFile built on a file object
            # leaves that file open; it is closed by the enclosing `with`.
            if stream is not file:
                stream.close()
    return {
        "meta": meta,
        "events": events,
    }


def readTimestamp(filename: pathlib.Path):
    """Return the ``unix_us`` start timestamp from a profiling file's metadata, as int.

    A ``.json`` file is fully loaded via ``readJSON``. For a ``.txt`` file only the
    first (JSON metadata) line is read.
    """
    if filename.suffix != ".json":
        assert filename.suffix == ".txt"
        with filename.open("rb") as handle:
            header = json.loads(handle.readline())
    else:
        header = readJSON(filename)["meta"]
    return int(header["unix_us"])


def createProfilingDB(con: sqlite3.Connection) -> sqlite3.Cursor:
    """Create the merged-profiling schema on ``con`` and commit it.

    Tables:
      * ``names(eid, name)`` - unique event names
      * ``participants(pid, name, size)`` - unique participants with their size
      * ``events(pid, rank, eid, ts, dur, data)`` - one row per recorded event;
        ``data`` holds a JSON string or NULL
    View:
      * ``full_events`` - ``events`` joined with ``names`` and ``participants``

    Returns:
        The cursor used to create the schema.
    """
    cur = con.cursor()
    cur.execute(
        "CREATE TABLE names(eid INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE)"
    )
    cur.execute(
        "CREATE TABLE participants(pid INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT UNIQUE, size INT)"
    )
    # `data` is declared TEXT: the former type name "TXT" does not contain "TEXT"
    # and therefore gets NUMERIC affinity in SQLite.
    cur.execute(
        "CREATE TABLE events("
        "pid TINYINT, rank INT, eid SMALLINT, ts INT, dur INT, data TEXT, "
        "FOREIGN KEY(pid) REFERENCES participants(pid), "
        "FOREIGN KEY(eid) REFERENCES names(eid))"
    )

    # Add a view for easier iteration
    cur.execute(
        """
        CREATE VIEW full_events AS
        SELECT p.name AS participant, e.rank AS rank, p.size AS size, n.name AS event,
               e.ts AS ts, e.dur AS dur, e.data AS data
        FROM events e
        INNER JOIN names n ON e.eid = n.eid
        INNER JOIN participants p ON e.pid = p.pid
        """
    )
    con.commit()
    return cur


@cache
def addOrFetchParticipant(cur: sqlite3.Cursor, name: str, size: int):
    """Return the integer ``pid`` of participant ``name``, inserting it if it is new.

    A participant that already exists keeps its stored size (``INSERT OR IGNORE``).
    Results are memoized per ``(cur, name, size)``.
    """
    cur.execute(
        "INSERT OR IGNORE INTO participants (name, size) VALUES (?, ?)", (name, size)
    )
    (pid,) = cur.execute(
        "SELECT pid FROM participants WHERE name = ?", (name,)
    ).fetchone()
    return int(pid)


@cache
def addFetchName(cur: sqlite3.Cursor, name: str):
    """Return the ``eid`` of event ``name``, registering the name if it is new.

    Results are memoized per ``(cur, name)``.
    """
    cur.execute("INSERT OR IGNORE INTO names(name) VALUES (?)", (name,))
    row = cur.execute("SELECT eid FROM names WHERE name = ?", (name,)).fetchone()
    return row[0]


def insertEvent(cur, pid, rank, name, ts, dur, data):
    """Insert one event row, registering its ``name`` through ``addFetchName``.

    ``data`` is serialized to a JSON string. Falsy data (``None``, ``""``, ``{}``) is
    stored as NULL.
    """
    eid = addFetchName(cur, name)

    payload = None
    if data:
        payload = json.dumps(data)
        # orjson.dumps returns bytes, which SQLite would store as a BLOB. Decode it
        # so the column always holds JSON text, whichever json backend was imported.
        if isinstance(payload, bytes):
            payload = payload.decode("utf-8")

    cur.execute(
        "INSERT INTO events VALUES (?, ?, ?, ?, ?, ?)",
        (pid, rank, eid, ts, dur, payload),
    )


def alignEvents(con: sqlite3.Connection):
    """Aligns passed events of multiple ranks and or participants.

    All ranks of a participant align at initialization, ensured by a barrier in preCICE.
    Primary ranks of all participants align after successfully establishing primary connections.

    Timestamps in ``events`` are updated in place. The changes are not committed here.
    """

    cur = con.cursor()

    print("Align participant ranks")
    # Shift every rank of a participant so that the end of its '*com.initializeIntraCom'
    # event coincides with the earliest such end among the ranks of that participant.
    cur.execute(
        """
            UPDATE events
            SET ts = ts + delta
            FROM (SELECT e.pid, rank, mins.mints-(ts+dur) AS delta
                  FROM events e
                  INNER JOIN names n ON e.eid = n.eid
                  INNER JOIN (
                      SELECT pid, min(ts+dur) AS mints
                      FROM events
                      INNER JOIN names ON events.eid = names.eid
                      WHERE name GLOB '*com.initializeIntraCom'
                      GROUP BY pid
                      ) AS mins ON mins.pid = e.pid
                  WHERE name GLOB '*com.initializeIntraCom'
                  AND delta != 0) AS mins
            WHERE events.pid == mins.pid AND events.rank == mins.rank
            """
    )

    print("Align participants")
    # Pair the primary-rank accept event of an acceptor with the primary-rank request
    # event of a requester. The accept event name ends with the requester's name and
    # the request event name ends with the acceptor's name. Both must match: otherwise
    # an acceptor connected to several requesters pairs every accept with every
    # request, and each requester gets shifted multiple times.
    cur.execute(
        """
            SELECT acc.pid, accpart.name, acc.ts+acc.dur, req.pid, reqpart.name, req.ts+req.dur
            FROM events acc
            INNER JOIN names accname ON acc.eid = accname.eid
            INNER JOIN participants accpart ON acc.pid = accpart.pid
            INNER JOIN events req
            INNER JOIN names reqname ON req.eid = reqname.eid
            INNER JOIN participants reqpart ON req.pid = reqpart.pid
            WHERE accname.name GLOB ('*m2n.acceptPrimaryRankConnection.' || reqpart.name)
            AND reqname.name GLOB ('*m2n.requestPrimaryRankConnection.' || accpart.name)
            AND acc.rank = 0
            AND req.rank = 0
            """
    )
    # Align participant pairs one at a time for ease of debugging.
    # NOTE(review): all pairs are fetched before any shift is applied, so a chained
    # alignment (A <- B <- C) uses B's pre-shift timestamp when aligning C.
    syncs = cur.fetchall()
    for _accpid, accname, accts, reqpid, reqname, reqts in syncs:
        delta = accts - reqts
        print(f"Aligning {accname} with {reqname} shift latter by {delta}")
        cur.execute(
            """
                UPDATE events
                SET ts = ts + ?
                WHERE pid = ?
                """,
            (delta, reqpid),
        )


def groupEvents(
    cur: sqlite3.Cursor, pid: int, rank: int, events: [dict], initTime: int
):
    """Pair begin/end events of one rank and insert them via ``insertEvent``.

    ``events`` is the raw event list of one profiling file. It holds name definitions
    (``et == "n"``), begins (``"b"``), ends (``"e"``) and data records (``"d"``).
    Each stored event is named by its nesting path joined with ``/``. The ``_GLOBAL``
    event is never pushed onto that path. Stored start times are shifted by
    ``initTime``. Data records are attached as ``{data_name: int}`` to the open event
    with the same name.

    Events still open when the input ends (truncated file) are stored with a duration
    reaching up to the latest timestamp seen in ``events``.

    Note: the dicts in ``events`` are modified in place.
    """

    # Expands event names: replaces numeric ids by names and skips name records
    def namedEvents():
        nameMap = {int(e["eid"]): e["en"] for e in events if e["et"] == "n"}
        for e in events:
            kind = e["et"]
            if kind == "n":
                continue
            e["eid"] = nameMap[int(e["eid"])]
            if kind == "d":
                e["dn"] = nameMap[int(e["dn"])]
            yield e

    active = {}  # short event name -> its begin event ("eid" holds the full name)
    stack = []  # short names of the currently open events, outermost first
    lastTS = None  # latest raw timestamp seen, used to close truncated events

    for event in namedEvents():
        kind = event["et"]

        name: str = event["eid"]
        assert isinstance(name, str)

        if "ts" in event:
            stamp = int(event["ts"])
            lastTS = stamp if lastTS is None else max(lastTS, stamp)

        # Handle event starts
        if kind == "b":
            if name in active:
                print(f"Ignoring start of active event {name}")
            else:
                event["ts"] = int(event["ts"])
                event["eid"] = "/".join(stack + [name])
                active[name] = event
                if name != "_GLOBAL":
                    stack.append(name)
        # Handle event stops
        elif kind == "e":
            if name not in active:
                print(f"Ignoring end of inactive event {name}")
            else:
                begin = active.pop(name)
                fullName = begin["eid"]
                dur = int(event["ts"]) - begin["ts"]
                ts = begin["ts"] + initTime
                data = begin.get("data", "")
                insertEvent(cur, pid, rank, fullName, ts, dur, data)
                if name != "_GLOBAL":
                    assert (
                        stack[-1] == name
                    ), f"Expected to end event {name} but the currently active event is {stack[-1]}. Note that events need to follow a strict nesting and overlapping starts/stops are not permitted."
                    stack.pop()
        # Handle event data
        elif kind == "d":
            if name not in active:
                print(f"Dropping data event {name} as event isn't yet known.")
            else:
                dname = event["dn"]
                assert isinstance(dname, str)
                active[name].setdefault("data", {})[dname] = int(event["dv"])

    # Handle leftover events in case of a truncated input file. `active` is keyed by
    # the short name; the stored begin event carries the full name. Any open event
    # came from a begin record with a timestamp, so lastTS is set here.
    for begin in active.values():
        fullName = begin["eid"]
        print(f"Truncating event without end {fullName}")
        ts = begin["ts"] + initTime
        dur = lastTS - begin["ts"]
        insertEvent(cur, pid, rank, fullName, ts, dur, begin.get("data"))


def loadProfilingOutputs(con: sqlite3.Connection, filenames: list[pathlib.Path]):
    """Create the profiling schema on ``con`` and load every file in ``filenames``.

    ``.json`` files are read with ``readJSON``, everything else with ``readTXT``.
    A file is skipped with a warning if it is empty or unreadable, or lacks
    ``meta`` or ``events``. A missing or unexpected ``file_version`` only triggers
    a warning. Participant name, rank, size and start time (``unix_us``) come from
    the file's metadata.
    """
    cur = createProfilingDB(con)

    # Load all files
    print(f"Loading {len(filenames)} event files")
    for fn in filenames:
        print(f"Loading {fn}")
        # Named `content` so the module-level `json` is not shadowed
        content = readJSON(fn) if fn.suffix == ".json" else readTXT(fn)

        # General checks
        if not content:
            warning(
                "The file is empty or could not be loaded and will be ignored.", fn
            )
            continue
        if "meta" not in content:
            warning("The file doesn't contain metadata and will be ignored.", fn)
            continue
        if "events" not in content:
            warning("The file doesn't contain event data and will be ignored.", fn)
            continue

        meta = content["meta"]
        version = meta.get("file_version")
        if version is None:
            warning(
                "The file doesn't contain a version (preCICE version v3.2 or earlier) and may be incompatible.",
                fn,
            )
        elif version == 1:
            warning(
                "The file uses development version 1, upgrading to a newer preCICE version is highly recommended.",
                fn,
            )
        elif version != RUN_FILE_VERSION:
            warning(
                f"The file uses version {version}, which doesn't match the expected version {RUN_FILE_VERSION} and may be incompatible.",
                fn,
            )

        # Grouping events
        name = meta["name"]
        rank = int(meta["rank"])
        size = int(meta["size"])

        pid = addOrFetchParticipant(cur, name, size)

        unix_us = int(meta["unix_us"])
        print(f"Processing {fn}")
        groupEvents(cur, pid, rank, content["events"], unix_us)
        # Drop the parsed file before reading the next one to limit peak memory
        del content


def detectFiles(files: list[pathlib.Path]):
    """Resolve the given files and directories to the profiling files to merge.

    Existing files are taken as-is. Directories are searched recursively for files
    named ``<name>-<rank>-<size>.json`` or ``<name>-<rank>-<size>.txt``. Anything
    else is reported and skipped.

    Returns:
        A sorted list of unique paths. Sorting gives a deterministic load order and
        thus reproducible participant/event ids in the database.
    """
    import re

    # '.' is escaped so only a literal dot separates the size from the extension
    nameMatcher = re.compile(r".+-\d+-\d+\.(json|txt)")

    def searchDir(directory: pathlib.Path):
        assert directory.is_dir()
        # rglob already recurses into subdirectories
        return [
            candidate
            for pattern in ("*.json", "*.txt")
            for candidate in directory.rglob(pattern)
            if candidate.is_file() and nameMatcher.fullmatch(candidate.name)
        ]

    resolved = []
    for path in files:
        if path.is_file():
            resolved.append(path)
            continue
        if path.is_dir():
            detected = searchDir(path)
            if not detected:
                print(f"Nothing found in {path}")
            else:
                print(f"Found {len(detected)} files in {path}")
                resolved += detected
        else:
            print(f'Cannot interpret "{path}"')

    unique = sorted(set(resolved))
    if len(files) > 1:
        print(f"Found {len(unique)} profiling files in total")
    return unique


def findFilesOfLatestRun(name, sizes):
    """Select the files of the most recent run of participant ``name``.

    ``sizes`` maps a run size to a ``{rank: filename}`` dict and must contain more
    than one run. One file per run serves as its representative. The run whose
    representative has the largest ``unix_us`` timestamp wins; ties keep the run
    encountered first.

    Returns:
        The list of filenames of the selected run.
    """
    assert len(sizes) > 1
    print(f"Found multiple runs for participant {name}")

    latestSize, latestStamp = None, None
    for size, ranks in sizes.items():
        assert len(ranks) > 0
        representative = next(iter(ranks.values()))
        stamp = readTimestamp(representative)
        if latestStamp is None or stamp > latestStamp:
            latestSize, latestStamp = size, stamp

    print(f"`-selected latest run of size {latestSize}")
    return list(sizes[latestSize].values())


def groupRuns(files: list[pathlib.Path]):
    """Group profiling files by participant name, run size and rank.

    File stems must follow ``<name>-<rank>-<size>``; the name may contain dashes.

    Returns:
        ``{name: {size: {rank: filename}}}``. If the same (name, size, rank) occurs
        more than once, a ``.txt`` file is preferred over a ``.json`` file. Otherwise
        the file found first is kept. Every conflict is reported as a warning.
    """
    PieceFile = namedtuple("PieceFile", ["name", "rank", "size", "filename"])

    def info(filename: pathlib.Path):
        parts = filename.stem.split("-")
        name = "-".join(parts[:-2])
        return PieceFile(name, int(parts[-2]), int(parts[-1]), filename)

    pieces = [info(filename) for filename in files]

    runs = {}  # not named `map` to avoid shadowing the builtin
    for n, r, s, fn in pieces:
        rankMap = runs.setdefault(n, {}).setdefault(s, {})
        existing = rankMap.get(r)
        if existing is None:
            rankMap[r] = fn
        elif fn.suffix == ".txt" and existing.suffix != ".txt":
            warning(
                f"Newer .txt replaces previously found .json file '{existing}'", fn
            )
            rankMap[r] = fn
        elif existing.suffix == ".txt" and fn.suffix != ".txt":
            warning(
                f"Ignored new .json due to conflict with existing .txt '{existing}'",
                fn,
            )
        else:
            # Same format found twice (e.g. in different directories): keep the first
            warning(
                f"Ignored duplicate file due to conflict with existing '{existing}'",
                fn,
            )

    return runs


def sanitizeFiles(files: list[pathlib.Path]):
    """Select the files to load, keeping only one run per participant.

    A participant with a single run contributes all of its files. For participants
    with several runs (different sizes), ``findFilesOfLatestRun`` picks one.

    Returns:
        The filenames of the selected runs of all participants.
    """
    runs = groupRuns(files)
    filesToLoad = []
    for name, sizes in runs.items():
        if len(sizes) == 1:
            print(f"Found a single run for participant {name}")
            for ranks in sizes.values():
                filesToLoad.extend(ranks.values())
            continue

        # Extend rather than assign, so files already collected for other
        # participants are kept
        filesToLoad.extend(findFilesOfLatestRun(name, sizes))
    return filesToLoad


def runMerge(ns):
    """Run the merge command for the parsed arguments ``ns``.

    Reads ``ns.files``, ``ns.output`` and ``ns.no_align``; returns the exit code
    of ``mergeCommand``.
    """
    align = not ns.no_align
    return mergeCommand(ns.files, ns.output, align)


def mergeCommand(files, outfile, align):
    """Merge profiling files into a fresh SQLite database at ``outfile``.

    Args:
        files: files and/or directories to search for profiling files.
        outfile: path of the resulting database (``str`` or ``pathlib.Path``).
            An existing file is deleted first.
        align: whether to align timestamps of ranks and participants.

    Returns:
        0 as the process exit code.
    """
    resolved = detectFiles(files)
    sanitized = sanitizeFiles(resolved)

    outfile = pathlib.Path(outfile)  # also accept plain strings from programmatic callers
    outfile.unlink(missing_ok=True)
    con = sqlite3.connect(outfile)
    try:
        loadProfilingOutputs(con, sanitized)

        if align:
            alignEvents(con)

        # commit and tidy up
        con.commit()
        con.execute("VACUUM")
    finally:
        # Always release the database handle, even if loading or aligning fails
        con.close()

    return 0


def makeMergeParser(add_help: bool = True):
    """Build the argument parser of the ``merge`` command.

    Pass ``add_help=False`` to reuse the parser as a parent of a sub-command parser.
    The parsed namespace provides ``files`` (list of paths, defaulting to the current
    directory), ``output`` (path of the resulting database) and ``no_align`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Merges preCICE profiling output files to a single file used by the other commands.",
        add_help=add_help,
    )
    parser.add_argument(
        "files",
        nargs="*",
        type=pathlib.Path,
        default=[pathlib.Path(".")],
        help="The profiling files to process, directories to search, or nothing to autodetect",
    )
    parser.add_argument(
        "-o",
        "--output",
        default="profiling.db",
        type=pathlib.Path,
        help="The resulting profiling file.",
    )
    parser.add_argument(
        "-n",
        "--no-align",
        action="store_true",
        help="Don't align participants?",
    )
    return parser


def main():
    """Entry point: print a deprecation notice to stderr, then parse and run ``merge``."""
    notice = (
        "This script is deprecated, will no longer receive upgrades, only provides the merge command, and will be removed in preCICE release 4.\n"
        "Please migrate to the precice-cli (pipx install precice-cli) or the new profiling scripts repository (pipx install precice-profiling).\n"
        "More information at https://precice.org/tooling-overview.html"
    )
    print(notice, file=sys.stderr)

    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(title="commands", dest="cmd", required=True)
    commands.add_parser("merge", parents=[makeMergeParser(add_help=False)])
    return runMerge(parser.parse_args())


if __name__ == "__main__":
    sys.exit(main())
