# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
import glob
import os
import re
import socket
import sys
import textwrap
from enum import Enum

from nsys_recipe.lib import recipe
from nsys_recipe.log import logger


class Option(Enum):
    """Common recipe options.

    Each member selects one predefined command-line argument that
    ArgumentParser.add_argument_to_group registers. Within this file members
    are only compared by identity; their numeric values are not used.
    """

    OUTPUT = 0  # --output
    FORCE_OVERWRITE = 1  # --force-overwrite
    ROWS = 2  # --rows
    START = 3  # --start (deprecated, hidden from help)
    END = 4  # --end (deprecated, hidden from help)
    NVTX = 5  # --nvtx (Stats / Expert System variant, not --filter-nvtx)
    BASE = 6  # --base
    MANGLED = 7  # --mangled
    BINS = 8  # --bins
    CSV = 9  # --csv
    INPUT = 10  # --input
    DISABLE_ALIGNMENT = 11  # --disable-alignment
    FILTER_NVTX = 12  # --filter-nvtx
    FILTER_TIME = 13  # --filter-time
    FILTER_PROJECTED_NVTX = 14  # --filter-projected-nvtx
    HIDE_INACTIVE = 15  # --hide-inactive
    PER_GPU = 16  # --per-gpu
    PER_STREAM = 17  # --per-stream


def _replace_range(name, start_index, end_index, value):
    return name[:start_index] + str(value) + name[end_index + 1 :]


def _substitute_env_var(name, index):
    """Expand a ``%q{ENV_VAR}`` expression in *name*.

    Args:
        name: The output name being expanded.
        index: Position of the ``q`` that follows the ``%``.

    Returns:
        A ``(new_name, resume_index)`` tuple. On a missing ``{`` or ``}``
        an error is logged and ``resume_index`` is ``len(name)``, which stops
        further scanning. If the variable is unset a warning is logged and
        the expression is left in place.
    """
    brace = index + 1
    if not (brace < len(name) and name[brace] == "{"):
        logger.error("Missing '{' token after '%q' expression.")
        return name, len(name)

    var_start = brace + 1
    closing = name.find("}", var_start)
    if closing < 0:
        logger.error("Missing '}' token after '%q{' expression.")
        return name, len(name)

    var_name = name[var_start:closing]
    expansion = os.getenv(var_name)
    if expansion is None:
        logger.warning(f"Environment variable '{var_name}' is not set.")
        return name, closing + 1

    # Replace everything from the '%' up to and including the '}'.
    percent = index - 1
    expanded = _replace_range(name, percent, closing, expansion)
    return expanded, percent + len(expansion)


def _substitute_hostname(name, index):
    """Expand a ``%h`` expression in *name* with ``socket.gethostname()``.

    Args:
        name: The output name being expanded.
        index: Position of the ``h`` that follows the ``%``.

    Returns:
        A ``(new_name, resume_index)`` tuple. If the host name cannot be
        obtained a warning is logged and the expression is replaced by an
        empty string.
    """
    try:
        host = socket.gethostname()
    except socket.error as err:
        logger.warning(f"Unable to get host name: {err}")
        host = ""

    percent = index - 1
    # The two-character "%h" spans [percent, index] inclusive.
    expanded = _replace_range(name, percent, index, host)
    return expanded, percent + len(host)


def _substitute_pid(name, index):
    """Expand a ``%p`` expression in *name* with the current process ID.

    Args:
        name: The output name being expanded.
        index: Position of the ``p`` that follows the ``%``.

    Returns:
        A ``(new_name, resume_index)`` tuple.
    """
    pid_text = str(os.getpid())
    percent = index - 1
    # The two-character "%p" spans [percent, index] inclusive.
    return _replace_range(name, percent, index, pid_text), percent + len(pid_text)


def _substitute_counters(name, counter_indices):
    """Replace every ``%n`` placeholder with the first number giving a free path.

    Args:
        name: Output name that still contains the ``%n`` placeholders.
        counter_indices: Positions of the ``%`` of each ``%n`` in *name*.

    Returns:
        *name* with all placeholders replaced by the same smallest positive
        integer for which ``os.path.exists`` is false, or *name* unchanged if
        there are no placeholders.

    Raises:
        ValueError: if every candidate number below ``sys.maxsize`` is taken.
    """
    if not counter_indices:
        return name

    # Substitute right-to-left so that replacing a two-character "%n" with a
    # longer number never shifts the placeholders that are still pending.
    positions = sorted(counter_indices, reverse=True)
    for num in range(1, sys.maxsize):
        candidate = name
        for index in positions:
            candidate = _replace_range(candidate, index, index + 1, num)
        if not os.path.exists(candidate):
            return candidate

    raise ValueError("Maximum limit reached. Unable to find an available output name.")


def process_output(name):
    """Expand the ``%`` expressions in an output name.

    Used as the argparse ``type=`` callable for ``--output``. Supported
    expressions:

    - ``%q{ENV_VAR}``: value of the environment variable ``ENV_VAR``.
    - ``%h``: host name of the system.
    - ``%p``: current process ID.
    - ``%n``: smallest positive integer for which the final path does not
      exist yet (all ``%n`` occurrences receive the same number).
    - ``%%``: a literal ``%``.

    A ``%q`` expression missing its braces is logged and stops further
    expansion. Unset variables and unknown expressions are logged and left
    in place. A trailing lone ``%`` is logged, and the name expanded so far
    is returned without resolving ``%n``.

    Returns:
        The expanded output name.

    Raises:
        ValueError: if no free number can be found for ``%n``.
    """
    counter_indices = []

    index = name.find("%")
    while index != -1:
        # Move onto the character that selects the expression.
        index += 1
        if index >= len(name):
            # Plain literal on purpose: applying the '%' operator to two
            # strings here would raise TypeError instead of logging.
            logger.error("Unterminated '%' expression.")
            return name

        token = name[index]
        if token == "q":
            name, index = _substitute_env_var(name, index)
        elif token == "h":
            name, index = _substitute_hostname(name, index)
        elif token == "p":
            name, index = _substitute_pid(name, index)
        elif token == "n":
            # Counters are resolved last, once the rest of the name is final,
            # so the existence check sees the complete path. Later
            # substitutions only touch positions after this one, so the
            # recorded index stays valid.
            counter_indices.append(index - 1)
        elif token == "%":
            # Drop the second '%' of "%%"; scanning resumes after the kept one.
            name = _replace_range(name, index, index, "")
        else:
            logger.error(f"Unknown expression '%{token}'.")

        index = name.find("%", index)

    return _substitute_counters(name, counter_indices)


def process_directory(report_dir):
    """Collect the report files located directly in *report_dir*.

    Args:
        report_dir: Directory to scan; it is converted to an absolute path.

    Returns:
        Absolute paths of the ``*.nsys-rep`` matches followed by the
        ``*.qdrep`` matches, each group in the order returned by glob.

    Raises:
        argparse.ArgumentTypeError: if no report file is found.
    """
    root = os.path.abspath(report_dir)
    reports = [
        match
        for pattern in ("*.nsys-rep", "*.qdrep")
        for match in glob.glob(os.path.join(root, pattern))
    ]
    if reports:
        return reports
    raise argparse.ArgumentTypeError("No nsys-rep files found.")


def process_input(path):
    """Resolve one ``--input`` value to report file path(s).

    Used as the argparse ``type=`` callable for ``--input``.

    Args:
        path: Path to a ``.nsys-rep``/``.qdrep`` file, or to a directory,
            optionally followed by ``:n`` to keep only the first ``n``
            report files (in sorted order) of that directory.

    Returns:
        The path itself for a file, or a sorted list of report file paths
        for a directory.

    Raises:
        argparse.ArgumentTypeError: if the path does not exist, is a file
            with an unsupported extension, uses ``:n`` on a file, has a
            non-integer or non-positive ``n``, or is a directory without
            report files.
    """
    extensions = (".qdrep", ".nsys-rep")
    n = None

    # Only treat a colon as the ':n' separator when the path as written does
    # not exist, so existing paths that contain ':' are still accepted.
    if ":" in path and not os.path.exists(path):
        path, count = path.rsplit(":", 1)
        try:
            n = int(count)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                "Expecting an integer value after the colon in the path."
            ) from e

    if os.path.isfile(path):
        if n is not None:
            raise argparse.ArgumentTypeError(
                "The ':n' syntax cannot be used for files."
            )
        if not path.endswith(extensions):
            raise argparse.ArgumentTypeError(f"{path} is not a nsys-rep file.")
        return path

    if os.path.isdir(path):
        # ':0' would silently select no files and a negative value would
        # slice from the end, so only positive limits are meaningful.
        if n is not None and n < 1:
            raise argparse.ArgumentTypeError(
                "The value after the colon in the path must be a positive integer."
            )
        files = sorted(
            file
            for extension in extensions
            for file in glob.glob(os.path.join(path, f"*{extension}"))
        )
        if not files:
            raise argparse.ArgumentTypeError(f"{path} does not contain nsys-rep files.")
        return files[:n]

    raise argparse.ArgumentTypeError(f"{path} does not exist.")


def process_integer(min_value):
    """Build an argparse ``type=`` callable that parses an integer >= *min_value*.

    argparse's ``type`` must be a callable that accepts a single string, so
    the lower bound is captured in a closure rather than passed as a second
    argument.

    Args:
        min_value: Smallest accepted value (inclusive).

    Returns:
        A function mapping a string to an ``int``. It raises
        ``argparse.ArgumentTypeError`` when the string is not an integer or
        when the value is less than *min_value*.
    """

    def type_function(number_str):
        try:
            number = int(number_str)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                "The argument must be an integer number"
            ) from e
        if number < min_value:
            raise argparse.ArgumentTypeError(
                f"The argument must be greater or equal to {min_value}"
            )
        return number

    return type_function


class TextHelpFormatter(argparse.HelpFormatter):
    """Help formatter that wraps long lines but keeps explicit line breaks.

    Unlike the default formatter, newlines in description and help strings
    are honored. Each original line is wrapped to the available width on
    its own.
    """

    def _fill_text(self, text, width, indent=""):
        wrapper = textwrap.TextWrapper(
            width=width, initial_indent=indent, subsequent_indent=indent
        )
        wrapped = []
        for paragraph in text.splitlines():
            wrapped.append(wrapper.fill(paragraph))
        return "\n".join(wrapped)

    def _split_lines(self, text, width):
        filled = self._fill_text(text, width)
        return filled.split("\n")


class ModeAction(argparse.Action):
    """Map a user-facing mode string (e.g. "dask-futures") to recipe.Mode.

    Unless the caller supplies ``choices``, the accepted strings are derived
    from the recipe.Mode member names: lower-cased, with "_" shown as "-".
    """

    def __init__(self, **kwargs):
        if "choices" not in kwargs:
            kwargs["choices"] = tuple(
                member.name.replace("_", "-").lower() for member in recipe.Mode
            )
        super().__init__(**kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        member_name = values.replace("-", "_").upper()
        setattr(namespace, self.dest, recipe.Mode[member_name])


class InputAction(argparse.Action):
    """Store the collected values as one flat list.

    The ``type=`` converter may turn a single value into a list (e.g. a
    directory expanding to several files). Such lists are spliced into the
    result one level deep, and any other value is kept as-is.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        flattened = []
        for item in values:
            flattened.extend(item if isinstance(item, list) else [item])
        setattr(namespace, self.dest, flattened)


class NvtxFilterAction(argparse.Action):
    """Parse a ``range[@domain][/index]`` value for NVTX range filtering.

    Stores a ``(range_name, domain_name, index)`` tuple on the namespace:

    - ``domain_name`` is None when no ``@domain`` part is given, including a
      bare trailing ``@``.
    - ``index`` is None when no ``/index`` part is given, otherwise the
      parsed integer.

    A literal ``@`` or ``/`` inside a name must be escaped with a backslash.
    The escapes are removed from the stored names.

    Raises:
        argparse.ArgumentError: on more than one unescaped ``@`` or ``/``,
            or when the index is not an integer.
    """

    _SEPARATORS = ("@", "/")

    def __call__(self, parser, namespace, values, option_string=None):
        # Split the composite value of range[@domain][/index] into its
        # components. Splitting works on the still-escaped text so escaped
        # separators never act as delimiters. Each name is then unescaped
        # for BOTH separators once all splits are done. Unescaping only the
        # separator of the current split used to leave "\/" in a range name
        # followed by a domain, and "\@" in a range name without a domain.
        before_at, after_at = self._find_seperator(values, "@")
        before_slash, after_slash = self._find_seperator(
            after_at if after_at else before_at, "/"
        )

        if not after_at:
            range_name = self._unescape_all(before_slash)
            domain_name = None
        else:
            range_name = self._unescape_all(before_at)
            domain_name = self._unescape_all(before_slash)

        try:
            index = int(after_slash) if after_slash else None
        except ValueError as e:
            raise argparse.ArgumentError(
                self, "Expecting an integer value for the index."
            ) from e

        setattr(namespace, self.dest, (range_name, domain_name, index))

    def _unescape(self, value, separator):
        """Turn every backslash-escaped *separator* into a plain one."""
        return value.replace(f"\\{separator}", separator)

    def _unescape_all(self, value):
        """Remove the escapes of all supported separators from *value*."""
        for separator in self._SEPARATORS:
            value = self._unescape(value, separator)
        return value

    def _find_seperator(self, value, separator):
        """Split *value* at its single unescaped *separator*.

        Returns:
            ``(before, after)`` with escapes left intact, or ``(value, "")``
            when no unescaped separator is present.

        Raises:
            argparse.ArgumentError: if more than one unescaped separator is
                found.
        """
        pattern = r"(?<!\\)" + re.escape(separator)
        positions = [match.start() for match in re.finditer(pattern, value)]

        if not positions:
            return value, ""

        if len(positions) > 1:
            raise argparse.ArgumentError(
                self,
                f"{self.dest} accepts only one '{separator}'."
                f" Any '{separator}' in the names should be escaped with a backslash.",
            )

        position = positions[0]
        return value[:position], value[position + 1 :]


class TimeFilterAction(argparse.Action):
    """Parse ``[start_time]/[end_time]`` into a ``(start, end)`` tuple.

    Each bound is an ``int`` or None when omitted. At least one bound must
    be present. When both are present, start must be strictly less than end.

    Raises:
        argparse.ArgumentError: on malformed input or an invalid range.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        split_values = values.split("/")

        # Both a wrong number of '/'-separated parts (unpacking) and a
        # non-integer bound (int()) raise ValueError. Catch only that instead
        # of a bare except, which would also swallow e.g. KeyboardInterrupt.
        try:
            start_time, end_time = (
                int(value) if value else None for value in split_values
            )
        except ValueError as e:
            raise argparse.ArgumentError(
                self,
                "Time range must be in the format [start_time]/[end_time] with integer values.",
            ) from e

        if start_time is None and end_time is None:
            raise argparse.ArgumentError(
                self,
                "At least one of the start_time or end_time must be provided.",
            )

        if start_time is not None and end_time is not None and start_time >= end_time:
            raise argparse.ArgumentError(self, "start_time must be less than end_time.")

        setattr(namespace, self.dest, (start_time, end_time))


class ArgumentParser(argparse.ArgumentParser):
    """Custom argument parser with predefined arguments.

    Arguments are organized into a "Context" group (filled by
    add_context_arguments) and a "Recipe" group (filled by
    add_recipe_argument). Common recipe options are registered from Option
    members via add_argument_to_group.
    """

    def __init__(self, **kwargs):
        # The built-in help argument is disabled and re-added below.
        kwargs["add_help"] = False
        super().__init__(**kwargs)

        self._context_group = self.add_argument_group("Context")
        self._recipe_group = self.add_argument_group("Recipe")
        # We add the help message manually to match the format of the other
        # messages.
        self.add_argument(
            "-h", "--help", action="help", help="Show this help message and exit."
        )

    @property
    def recipe_group(self):
        """The "Recipe" argument group."""
        return self._recipe_group

    def add_recipe_argument(self, option, *args, **kwargs):
        """Add an argument to the "Recipe" group.

        The parameters have the same meaning as in add_argument_to_group.
        """
        self.add_argument_to_group(self._recipe_group, option, *args, **kwargs)

    def add_argument_to_group(self, group, option, *args, **kwargs):
        """Add an argument to *group*.

        Args:
            group: Argument group (or parser) whose add_argument is called.
            option: Either an Option member, which registers the matching
                predefined argument, or any value accepted as the first
                positional argument of group.add_argument (e.g. a flag
                string). In the latter case *args and **kwargs are forwarded
                unchanged.
            *args: Extra positional arguments. They are ignored for Option
                members.
            **kwargs: Extra keyword arguments forwarded to group.add_argument.
                For Option members they are added to the predefined ones.
                Repeating a keyword the option already sets raises TypeError.

        Raises:
            NotImplementedError: if *option* is an Option member without a
                predefined argument.
        """
        if not isinstance(option, Option):
            group.add_argument(option, *args, **kwargs)
            return

        if option == Option.OUTPUT:
            # '%' is doubled because argparse %-formats help strings.
            group.add_argument(
                "--output",
                type=process_output,
                help="Output directory name.\n"
                "Any %%q{ENV_VAR} pattern in the filename will be substituted with the value of the environment variable.\n"
                "Any %%h pattern in the filename will be substituted with the hostname of the system.\n"
                "Any %%p pattern in the filename will be substituted with the PID.\n"
                "Any %%n pattern in the filename will be substituted with the minimal positive integer that is not already occupied.\n"
                "Any %%%% pattern in the filename will be substituted with %%.",
                **kwargs,
            )
        elif option == Option.FORCE_OVERWRITE:
            group.add_argument(
                "--force-overwrite",
                action="store_true",
                help="Overwrite existing directory.",
                **kwargs,
            )
        elif option == Option.ROWS:
            group.add_argument(
                "--rows",
                metavar="limit",
                type=int,
                default=-1,
                help="Maximum number of rows per input file.",
                **kwargs,
            )
        elif option == Option.START:
            # Deprecated in favor of --filter-time; hidden from the help.
            group.add_argument(
                "--start",
                metavar="time",
                type=int,
                help=argparse.SUPPRESS,
                **kwargs,
            )
        elif option == Option.END:
            # Deprecated in favor of --filter-time; hidden from the help.
            group.add_argument(
                "--end",
                metavar="time",
                type=int,
                help=argparse.SUPPRESS,
                **kwargs,
            )
        elif option == Option.FILTER_TIME:
            group.add_argument(
                "--filter-time",
                metavar="[start_time]/[end_time]",
                action=TimeFilterAction,
                help="Filter by time range in nanoseconds.",
                **kwargs,
            )
        # This is the NVTX option used by the nsysstats module for Stats and
        # Expert System. It behaves differently from the "--filter-nvtx" option
        # used by the recipe scripts.
        elif option == Option.NVTX:
            group.add_argument(
                "--nvtx",
                metavar="range[@domain]",
                type=str,
                help="Filter by NVTX range.",
                **kwargs,
            )
        elif option == Option.FILTER_NVTX:
            group.add_argument(
                "--filter-nvtx",
                metavar="range[@domain][/index]",
                type=str,
                action=NvtxFilterAction,
                help="Filter by NVTX range using only the start and end times of the matching ranges.\n"
                "Specify the domain only when the range is not in the default domain, or use '*' to include all domains."
                " Any '@' or '/' in the names should be escaped with a backslash.\n"
                "The index is zero-based and is used to select the nth range."
                " If no index is specified, all ranges will be used.",
                **kwargs,
            )
        elif option == Option.FILTER_PROJECTED_NVTX:
            group.add_argument(
                "--filter-projected-nvtx",
                metavar="range[@domain][/index]",
                type=str,
                action=NvtxFilterAction,
                help="Filter by projected NVTX range using only the start and end times of the matching ranges.\n"
                "Specify the domain only when the range is not in the default domain, or use '*' to include all domains."
                " Any '@' or '/' in the names should be escaped with a backslash.\n"
                "The index is zero-based and is used to select the nth range."
                " If no index is specified, all ranges will be used.",
                **kwargs,
            )
        elif option == Option.BASE:
            group.add_argument(
                "--base", action="store_true", help="Use kernel base name.", **kwargs
            )
        elif option == Option.MANGLED:
            group.add_argument(
                "--mangled",
                action="store_true",
                help="Use kernel mangled name.",
                **kwargs,
            )
        elif option == Option.BINS:
            group.add_argument(
                "--bins",
                type=process_integer(0),
                default=30,
                help="Number of bins (default: %(default)s).",
                **kwargs,
            )
        elif option == Option.CSV:
            group.add_argument(
                "--csv",
                action="store_true",
                help="Additionally output data as CSV.",
                **kwargs,
            )
        elif option == Option.INPUT:
            # process_input may return a list for a directory; InputAction
            # flattens those lists into a single list of paths.
            group.add_argument(
                "--input",
                type=process_input,
                default=None,
                nargs="+",
                action=InputAction,
                help="One or more paths to nsys-rep files or directories.\n"
                "Directories can optionally be followed by ':n' to limit the number of files.",
                **kwargs,
            )
        elif option == Option.DISABLE_ALIGNMENT:
            group.add_argument(
                "--disable-alignment",
                action="store_true",
                help="Disable automatic session alignment.\n"
                "By default, session times are aligned based on the epoch time of the report file collection.\n"
                "This option will instead use relative time, which is useful for comparing individual sessions.",
                **kwargs,
            )
        elif option == Option.HIDE_INACTIVE:
            group.add_argument(
                "--hide-inactive",
                action="store_true",
                help="Exclude devices with zero traffic from the results.\n"
                "By default, all devices are shown.",
                **kwargs,
            )
        elif option == Option.PER_GPU:
            # The stored value is the list of grouping column names.
            # kwargs are forwarded like for every other option; they were
            # previously dropped silently.
            group.add_argument(
                "--per-gpu",
                action="store_const",
                const=["deviceId"],
                default=[],
                help="Group events by GPU.",
                **kwargs,
            )
        elif option == Option.PER_STREAM:
            group.add_argument(
                "--per-stream",
                action="store_const",
                const=["deviceId", "streamId"],
                default=[],
                help="Group events by stream within each GPU.",
                **kwargs,
            )
        else:
            raise NotImplementedError("Invalid option.")

    def add_context_arguments(self):
        """Add the recipe execution ``--mode`` argument to the "Context" group.

        The value is converted to a recipe.Mode member by ModeAction and
        defaults to recipe.Mode.CONCURRENT.
        """
        self._context_group.add_argument(
            "--mode",
            action=ModeAction,
            default=recipe.Mode.CONCURRENT,
            help="Recipe execution mode:\n"
            "  - none: Sequential execution.\n"
            "  - concurrent: Parallel execution.\n"
            "  - dask-futures: Distributed execution.",
        )

    def parse_args(self, *args, **kwargs):
        """Parse arguments and reconcile --start/--end with --filter-time.

        A deprecation warning is logged for each of --start/--end that is
        given. If --filter-time was not given, it is derived as
        ``(start, end)`` from them. If it was given, its bounds are copied
        into ``start``/``end``, overriding any --start/--end values.
        """
        parsed_args = super().parse_args(*args, **kwargs)

        start = getattr(parsed_args, "start", None)
        end = getattr(parsed_args, "end", None)
        filter_time = getattr(parsed_args, "filter_time", None)

        if start is not None:
            logger.warning(
                "The '--start' option is deprecated and will be removed in a future version."
                " Please use the '--filter-time' option instead."
            )

        if end is not None:
            logger.warning(
                "The '--end' option is deprecated and will be removed in a future version."
                " Please use the '--filter-time' option instead."
            )

        if filter_time is None and (start is not None or end is not None):
            parsed_args.filter_time = (start, end)
        elif filter_time is not None:
            # This is necessary for Stats and the Expert System, which receive
            # "start" and "end" as options for time filtering. Once the "start"
            # and "end" options are removed, we can move the following lines to
            # TimeFilterAction.
            parsed_args.start, parsed_args.end = parsed_args.filter_time

        return parsed_args
