#!/usr/bin/python3
# SPDX-FileCopyrightText: 2024-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only
"""Find errors in tests/"""

import os
import re
import sys


# First line of a test script: a shebang invoking `/usr/share/ucs-test/runner`
# with `bash` as its argument.
RE_BASH = re.compile(r'''^#! */usr/share/ucs-test/runner +bash\b''')
# Shell function definition `name ()`, optionally preceded by the `function`
# keyword and leading spaces; group 1 is the function name.  The parentheses
# are required, so `function name {` is not matched.
RE_FUNCTION = re.compile(r'''^ *(?:function +)?([-0-9A-Z_a-z]+) *\(\)''')
# `source PATH` or `. PATH` statement, optionally indented with spaces; group 1
# is PATH as written, built from double-quoted, single-quoted and bare parts.
# NOTE: because of the trailing `\b` the match must end at a word boundary, so
# a closing quote followed by whitespace or end of line is left out of group 1.
RE_SOURCE = re.compile(r'''^ *(?:source|\.) +((?:"(?:[^"\\]|\\.)+"|'[^']+'|\S+)+)\b''')


class Lint:
    """Report tests calling a library function without sourcing its library.

    All ``*.sh`` files below :attr:`D_LIB` are scanned for shell function
    definitions and for the other libraries they source.  Afterwards every
    bash test is scanned line by line: a line mentioning a known library
    function is printed unless the defining library was sourced, directly or
    through another library, on an earlier line of that test.
    """

    D_LIB = "lib/"
    D_TEST = "tests/"
    # Function names too generic to be searched for in tests.
    GENERIC = ("ucr", "error", "info")

    def __init__(self):
        # function name -> path of the library which defined it first
        self.functions = {}
        # library path -> set of libraries sourced directly by that library
        self.dependencies = {}
        # compiled pattern matching any known function name, see compile_regexp()
        self.regexp = None

    def main(self) -> None:
        """Scan the libraries, then check the given tests or all tests.

        Every path given on the command line is checked; without arguments
        all tests below :attr:`D_TEST` are checked.
        """
        self.parse_libs()
        self.compile_regexp()
        # Inspect the arguments explicitly instead of catching IndexError, so
        # that an IndexError raised while processing cannot trigger a full scan.
        paths = sys.argv[1:]
        if not paths:
            self.parse_tests()
            return
        for path in paths:
            self.process_test(path)

    def parse_libs(self) -> None:
        """Process every ``*.sh`` file found below :attr:`D_LIB`."""
        for root, _dirs, files in os.walk(self.D_LIB):
            for name in files:
                if name.endswith('.sh'):
                    self.process_lib(os.path.join(root, name))

    def process_lib(self, path: str) -> None:
        """Record the functions defined and the libraries sourced by `path`.

        A function already defined by another library is reported on stdout
        and stays associated with the library that defined it first.

        :param path: path of the shell library to scan.
        """
        sourced = self.dependencies[path] = set()
        # Decode leniently: a stray non-UTF-8 byte must not abort the scan.
        with open(path, encoding="utf-8", errors="replace") as stream:
            for line in stream:
                match = RE_FUNCTION.match(line)
                if match:
                    func = match.group(1)
                    if func in self.functions:
                        print(f"'{func}' re-defined from '{self.functions[func]}' in '{path}'")
                    else:
                        self.functions[func] = path
                    continue

                dependency = self._sourced_lib(line)
                if dependency is not None:
                    sourced.add(dependency)

    def compile_regexp(self) -> None:
        """Build :attr:`regexp` matching any known function name.

        The names listed in :attr:`GENERIC` are removed from
        :attr:`functions` first.  A name only matches when it starts at a
        word boundary and is followed by whitespace or the end of the line.
        """
        for name in self.GENERIC:
            # pop() instead of del: the libraries may not define these names.
            self.functions.pop(name, None)

        if not self.functions:
            # An empty alternation would match the empty string and lead to a
            # bogus lookup of '' in self.functions; use a never-matching
            # pattern instead.
            self.regexp = re.compile(r'(?!)')
            return

        names = '|'.join(re.escape(func) for func in self.functions)
        self.regexp = re.compile(r'\b(?:%s)(?=$|\s)' % (names,))

    def parse_tests(self) -> None:
        """Process every file below :attr:`D_TEST` whose name has no dot."""
        for root, _dirs, files in os.walk(self.D_TEST):
            for name in files:
                if '.' not in name:
                    self.process_test(os.path.join(root, name))

    def process_test(self, path: str) -> None:
        """Print lines of `path` using a function of a library not yet sourced.

        Files whose first line does not match `RE_BASH` are skipped.  Each
        finding is printed as ``path:line-number:function stripped-line``.

        :param path: path of the test script to check.
        """
        libs = set()
        # Decode leniently: the tests directory may contain non-text files.
        with open(path, encoding="utf-8", errors="replace") as test:
            if not RE_BASH.match(next(test, "")):
                return

            # The shebang was line 1, so the remaining lines start at 2.
            for nr, line in enumerate(test, start=2):
                line = line.strip()
                lib = self._sourced_lib(line)
                if lib is not None:
                    libs.update(self._transitive_libs(lib))
                    continue

                match = self.regexp.search(line)
                if match is None:
                    continue
                func = match.group(0)
                if self.functions[func] not in libs:
                    print(f"{path}:{nr}:{func} {line}")

    def _sourced_lib(self, line: str):
        """Return the library sourced by `line`, or ``None`` if it sources none.

        ``$TESTLIBPATH/`` is replaced by :attr:`D_LIB` and surrounding double
        quotes are removed, so the result is comparable to the paths recorded
        by :meth:`process_lib`.
        """
        match = RE_SOURCE.match(line)
        if match is None:
            return None
        return match.group(1).replace("$TESTLIBPATH/", self.D_LIB).strip('"')

    def _transitive_libs(self, lib: str) -> set:
        """Return `lib` together with all libraries it sources, transitively."""
        seen = set()
        pending = [lib]
        while pending:
            current = pending.pop()
            if current in seen:
                # Already handled; this also terminates on cyclic sourcing.
                continue
            seen.add(current)
            pending.extend(self.dependencies.get(current, ()))
        return seen


if __name__ == '__main__':
    # Run the linter only when executed as a script, not when imported.
    linter = Lint()
    linter.main()
