#!/usr/bin/python3 -u
# SPDX-FileCopyrightText: 2024-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""UCS Testrunner - run UCS test in sane environment."""


import argparse
import logging
import operator
import sys
from collections.abc import Iterable, Mapping
from time import time

import univention.testing.format
from univention.testing.codes import Reason
from univention.testing.coverage import Coverage
from univention.testing.data import TestCase, TestEnvironment, TestFormatInterface, TestResult
from univention.testing.errors import TestError
from univention.testing.internal import LOG_BASE, get_sections, get_tests, setup_debug, setup_environment
from univention.testing.pytest import PytestRunner


class StoreTag(argparse.Action):
    """Argparse action collecting tags into a ``{tag: level}`` dictionary.

    Sub-classes set :attr:`tag_level`. Every occurrence of the option stores
    that level under the given tag in the dictionary kept at ``dest``; a later
    option naming the same tag overrides an earlier one.
    """

    tag_level = ""

    def __call__(self, parser, namespace, value, option_string=None):
        """Record `value` with this action's :attr:`tag_level`.

        :param parser: The parser invoking the action (unused).
        :param namespace: Namespace receiving the updated dictionary.
        :param value: The tag given on the command line.
        :param option_string: The option used to invoke the action (unused).
        """
        # Copy instead of updating in place: argparse stores the option's
        # `default` object itself on the namespace, so mutating it would leak
        # tags into every later parse done with the same parser.
        tags = dict(getattr(namespace, self.dest, None) or {})
        tags[value] = self.tag_level
        setattr(namespace, self.dest, tags)


class StoreRequired(StoreTag):
    """Tag action storing the level ``"required"`` for the given tag."""

    tag_level = "required"


class StoreProhibited(StoreTag):
    """Tag action storing the level ``"prohibited"`` for the given tag."""

    tag_level = "prohibited"


class StoreIgnored(StoreTag):
    """Tag action storing the level ``"ignored"`` for the given tag."""

    tag_level = "ignored"


def parse_options(sections: Iterable[str]) -> tuple[argparse.Namespace, list[str]]:
    """Build the command line parser and parse the process arguments.

    :param sections: Section names accepted by ``--section``.
    :returns: The parsed options and the list of arguments the parser did not recognise.
    """
    parser = argparse.ArgumentParser()

    # General options
    general = parser.add_argument
    general("-H", "--hold", action="store_true", help="Stop execution after the first failed test")
    general("-t", "--timeout", default=3600, type=int, help="Abort single test after [%(default)s]s")

    # Options deciding which tests are picked
    select = parser.add_argument_group('Test selection').add_argument
    select(
        "-s", "--section",
        dest="sections", action="append", choices=sections,
        help="Run tests only from this section", metavar="SECTION")
    # The three tag options share one destination and differ only in the
    # level their action records; each gets its own fresh default dictionary.
    tag_options = (
        (("-p", "--prohibit"), StoreProhibited, "Skip tests with this tag"),
        (("-r", "--require"), StoreRequired, "Only run tests with this tag"),
        (("-g", "--ignore"), StoreIgnored, "Neither require nor prohibit this tag"),
    )
    for flags, tag_action, description in tag_options:
        select(*flags, dest="tags", action=tag_action, default={}, help=description, metavar="TAG")
    select("-E", "--exposure", choices=('safe', 'careful', 'dangerous'), help="Run more dangerous tests")

    # Options controlling what is printed and where
    output = parser.add_argument_group('Output options').add_argument
    output("-n", "--dry-run", dest="dry", action="store_true", help="Only show which tests would run")
    output("-f", "--filtered", dest="filter", action="store_true", help="Hide tests with unmatched pre-conditions")
    output(
        "-F", "--format",
        choices=univention.testing.format.FORMATS, default='text',
        help="Select output format [%(default)s]")
    output("-v", "--verbose", action="count", help="Increase verbosity")
    output("-i", "--interactive", action="store_true", help="Run test connected to terminal")
    output("-c", "--count", action="store_true", help="Prefix tests by count")
    output("-l", "--logfile", default=LOG_BASE % (time(),), help="Path to log file [%(default)s]")

    # Let the other components register their own options
    PytestRunner.get_argument_group(parser)
    Coverage.get_argument_group(parser)

    return parser.parse_known_args()


class TestSet:
    """Container for tests.

    Holds the tests grouped by section together with the environment and the
    output formatter used while running them.
    """

    def __init__(self, tests: Mapping[str, Iterable[str]]) -> None:
        """Remember the tests and count them.

        :param tests: Mapping of section name to the test file names of that section.
        """
        self.tests = tests
        self.max_count = sum(1 for cases in tests.values() for _ in cases)
        self.test_environment: TestEnvironment | None = None
        self.format: TestFormatInterface | None = None
        self.prefix = ''

    def set_environment(self, test_environment: TestEnvironment) -> None:
        """Set environment for running tests."""
        self.test_environment = test_environment

    def set_format(self, format: str) -> None:
        """Select output format.

        :param format: Name of the format; the formatter is looked up as
            ``format_<name>`` in :mod:`univention.testing.format`.
        """
        formatter = getattr(univention.testing.format, f'format_{format}')
        self.format = formatter()

    def set_prefix(self, prefix: bool) -> None:
        """Enable or disable test numbering.

        :param prefix: If true, build a ``%0Nd/%0Nd `` template, zero-padded
            to the width of the total test count.
        """
        if prefix:
            count_width = len('%d' % (self.max_count,))
            self.prefix = '%%0%dd/%%0%dd ' % (count_width, count_width)
        else:
            self.prefix = ''

    def run_tests(self, filter_condition: bool = False, dry_run: bool = False, stop_on_failure: bool = False) -> int | None:
        """Run selected tests.

        Sections are processed in the order of their names.

        :param filter_condition: Silently skip tests whose pre-condition check fails.
        :param dry_run: Only report the tests, do not execute them.
        :param stop_on_failure: Return right after the first executed test whose result code is ``E`` or ``F``.
        :returns: `1` if the run was aborted because of `stop_on_failure`, `0` otherwise.
        """
        assert self.format
        assert self.test_environment
        self.format.begin_run(self.test_environment, self.max_count)
        try:
            count = 0
            # Order by section name (item 0), not by the collection of cases:
            # that keeps the order deterministic for empty or equal case lists
            # and does not require the case containers to be comparable.
            for section, cases in sorted(self.tests.items(), key=operator.itemgetter(0)):
                self.format.begin_section(section)
                try:
                    for fname in cases:
                        # Counted before filtering, so the number refers to the
                        # position among all tests, not only the reported ones.
                        count += 1
                        test_case = TestCase(fname)
                        test_result = TestResult(test_case, self.test_environment)
                        try:
                            test_case.load()
                        except TestError as ex:
                            logger = logging.getLogger('test')
                            failed_message = f'Failed to load test "{fname}": {ex}'
                            logger.critical(failed_message)
                            # A test which cannot be loaded is never filtered
                            # out, so the failure is always reported.
                            check = True
                        else:
                            failed_message = ""
                            check = test_result.check()

                        if filter_condition and not check:
                            continue

                        if section == 's4connector':
                            # Imported lazily: only needed for this section.
                            from univention.testing.utils import wait_for_s4_connector_to_be_inactive
                            wait_for_s4_connector_to_be_inactive()

                        if self.prefix:
                            self.format.begin_test(test_case, self.prefix % (count, self.max_count))
                        else:
                            self.format.begin_test(test_case)

                        try:
                            if failed_message:
                                test_result.reason = Reason.FAIL
                                test_result.attach('stdout', 'text/plain', failed_message)
                            elif not dry_run:
                                test_result.run()
                                if stop_on_failure and test_result.reason.eofs in 'EF':
                                    # The enclosing `finally` blocks still close
                                    # the test, the section and the run.
                                    return 1
                        finally:
                            self.format.end_test(test_result)
                finally:
                    self.format.end_section()
        finally:
            self.format.end_run()
        return 0


def main() -> int | None:
    """Run UCS test suite.

    Parses the command line, prepares the test environment from the options
    and runs the selected tests. Exits with status 2 if unknown arguments
    are given.

    :returns: The result of :meth:`TestSet.run_tests`.
    """
    all_sections = get_sections()

    (options, args) = parse_options(all_sections.keys())
    if args:
        logger = logging.getLogger('test')
        logger.error('Unused arguments: %r', args)
        sys.exit(2)

    setup_environment()
    setup_debug(options.verbose)

    # Without an explicit selection all known sections are used.
    selected_sections = options.sections or all_sections.keys()
    tests = get_tests(selected_sections)

    coverage = Coverage(options)
    coverage.start()
    # Everything after `start()` is guarded, so coverage is stopped again even
    # when preparing the environment fails before any test has been run.
    try:
        PytestRunner.set_arguments(options)
        test_set = TestSet(tests)

        if options.dry:
            # A dry run gets no log file.
            test_environment = TestEnvironment(interactive=options.interactive)
        else:
            test_environment = TestEnvironment(
                interactive=options.interactive,
                logfile=options.logfile)
        # Split the `{tag: level}` mapping built by the tag options by level.
        tags_required = [tag for tag, level in options.tags.items() if level == "required"]
        tags_ignored = [tag for tag, level in options.tags.items() if level == "ignored"]
        tags_prohibited = [tag for tag, level in options.tags.items() if level == "prohibited"]
        test_environment.tag(
            require=tags_required,
            ignore=tags_ignored,
            prohibit=tags_prohibited)
        if options.exposure:
            test_environment.set_exposure(options.exposure)
        test_environment.set_timeout(options.timeout)
        test_set.set_environment(test_environment)
        test_set.set_prefix(options.count)
        test_set.set_format(options.format)
        return test_set.run_tests(options.filter, options.dry, options.hold)
    finally:
        coverage.stop()


if __name__ == '__main__':
    # An interrupt by the user ends the program with exit status 1 instead
    # of a traceback; any other outcome uses the value returned by main().
    try:
        exit_status = main()
    except KeyboardInterrupt:
        exit_status = 1
    sys.exit(exit_status)
