Source code for clack._main

"""Defines the clack.main_factory() function."""

from __future__ import annotations

from pathlib import Path
import signal
import sys
from typing import (
    Any,
    Callable,
    Final,
    Iterable,
    List,
    Optional,
    Sequence,
    Type,
    cast,
    get_type_hints,
    overload,
)

from logrus import Log, Logger, init_logging
from typist import literal_to_list

from . import _dynvars as dyn
from ._helpers import filter_cli_args
from .types import ClackConfig, ClackMain, ClackParser, ClackRunner


# Assertion message used by main_factory() when its arguments are not exactly
# one of the two supported call forms.
ASSERT_MAIN_FACTORY_PRECOND: Final = (
    "EXACTLY ONE of the following MUST be true"
    " when calling the clack.main_factory() function:"
    " (1) The 'run' positional argument is provided"
    " OR (2) the 'runners' and 'parser' keyword-only arguments"
    " are BOTH provided."
)
# Assertion message describing the contract that every sub-command runner (and
# its Config class) must satisfy; appended to the errors raised while mapping
# a sub-command name to a Config type.
ASSERT_MAIN_RUNNERS_PRECOND: Final = (
    "ALL runners must have a clack.Config 'cfg' positional argument which"
    ' inherit from the SAME shared clack.Config subclass. This "shared'
    " Config\" should have a 'command' attribute that is typed as a"
    " typing.Literal, say CommandLiteral. Additionally, all subclasses of this"
    " shared Config MUST have a 'command' attribute that is typed as a"
    " typing.Literal which is a subtype of CommandLiteral"
)


# Typing overload 1 of 2: the single-command call form, where one runner is
# passed as the 'run' argument and no 'runners' / 'parser' are given.
@overload
def main_factory(  # noqa: E704
    app_name: str, run: ClackRunner
) -> ClackMain: ...


# Typing overload 2 of 2: the sub-command call form, where the keyword-only
# 'runners' and 'parser' arguments are both given and 'run' is omitted.
@overload
def main_factory(  # noqa: E704
    app_name: str, *, runners: Iterable[ClackRunner], parser: ClackParser
) -> ClackMain: ...


def main_factory(
    app_name: str,
    run: Optional[ClackRunner] = None,
    *,
    runners: Optional[Iterable[ClackRunner]] = None,
    parser: Optional[ClackParser] = None,
) -> ClackMain:
    """Factory used to create a new `main()` function.

    Exactly one of two call forms must be used (see
    ASSERT_MAIN_FACTORY_PRECOND): either ``run`` is given, or BOTH
    ``runners`` and ``parser`` are given.

    Args:
        app_name: Application name, forwarded to the clack environment
            helpers and to the logger.
        run: The single runner. Its config type is taken from the type hint
            of its 'cfg' parameter and instantiated via ``from_cli_args()``.
        runners: One runner per sub-command. The iterable is consumed ONCE,
            when the factory is called, so one-shot iterables (e.g.
            generators) are supported and the returned main() can be invoked
            repeatedly.
        parser: Callable that turns argv into a mapping of keyword arguments;
            its "command" entry selects which config type to build.

    Returns:
        A generic main() function to be used as a script's entry point. It
        returns the runner's status, ``128 + SIGINT`` when interrupted with
        Ctrl-C, or 1 when the runner raises any other exception.
    """
    run_is_set = run is not None
    kwargs_only = runners is not None and parser is not None
    assert run_is_set or kwargs_only, ASSERT_MAIN_FACTORY_PRECOND
    assert not (run_is_set and kwargs_only), ASSERT_MAIN_FACTORY_PRECOND

    # Snapshot the runners up front. Iterating `runners` again later (as the
    # code previously did) silently yields nothing for one-shot iterables.
    runner_list: List[ClackRunner] = (
        list(runners) if run is None and runners is not None else []
    )

    def main_run(argv: Sequence[str]) -> int:
        assert run is not None

        config_file = _get_config_file_from_argv(argv)
        config_type = _get_run_cfg(run)
        with dyn.clack_envvars_set(
            app_name, [config_type], config_file=config_file
        ):
            cfg = config_type.from_cli_args(argv)

        return do_main_work(run, cfg)

    def main_runners(argv: Sequence[str]) -> int:
        assert runners is not None
        assert parser is not None

        all_config_types = _get_all_config_types(runner_list)
        config_file = _get_config_file_from_argv(argv)
        with dyn.clack_envvars_set(
            app_name, all_config_types, config_file=config_file
        ):
            parser_kwargs = parser(argv)
            config_type = _config_type_from_command(
                all_config_types, parser_kwargs["command"]
            )
            filtered_kwargs = filter_cli_args(parser_kwargs)
            cfg = config_type(**filtered_kwargs)

        # Dispatch on the config's type using the SAME snapshot of runners.
        dispatch = _main_runner_factory(runner_list)
        return do_main_work(dispatch, cfg)

    def do_main_work(runner: ClackRunner, cfg: ClackConfig) -> int:
        verbose: int = getattr(cfg, "verbose", 0)
        logs: List[Log] = getattr(cfg, "logs", [])
        init_logging(logs=logs, verbose=verbose)

        logger = Logger("clack", app_name=app_name, cfg=cfg)

        # The following log messages will obviously only be visible if the
        # corresponding log level really is enabled, but stating the obvious in
        # this case seemed like the right thing to do so ¯\_(ツ)_/¯.
        logger.trace("TRACE level logging enabled.")
        logger.debug("DEBUG level logging enabled.")

        try:
            with dyn.clack_envvars_set(app_name, [type(cfg)], cfg=cfg):
                status = runner(cfg)
        except KeyboardInterrupt:  # pragma: no cover
            logger.info("Received SIGINT signal. Terminating script...")
            return 128 + signal.SIGINT.value
        except Exception:  # pragma: no cover
            logger.exception(
                "An unrecoverable error has been raised. Terminating script..."
            )
            return 1
        else:
            return status

    def wrap_main(outer_main: Callable[[Sequence[str]], int]) -> ClackMain:
        def inner_main(argv: Optional[Sequence[str]] = None) -> int:
            if argv is None:  # pragma: no cover
                argv = sys.argv

            # We first initialize logging here with no config, so we can log
            # messages in the clack parser. The verbosity is the number of
            # 'v' characters in the first argument that starts with "-v".
            verbose = 0
            for opt_or_arg in argv:
                if opt_or_arg.startswith("-v"):
                    verbose = opt_or_arg.count("v")
                    break
            init_logging(verbose=verbose)

            return outer_main(argv)

        return inner_main

    if run is None:
        return wrap_main(main_runners)
    else:
        return wrap_main(main_run)


def _get_config_file_from_argv(argv: Sequence[str]) -> Optional[Path]:
    """Return the config file named by -c / --config in argv, if any.

    Recognized spellings are ``-c FILE``, ``-cFILE``, ``--config FILE`` and
    ``--config=FILE``. All of argv is scanned for ``-c`` before ``--config``
    is considered.

    Returns:
        The config file path, or None when argv names no config file. An
        option given as the very last argument (i.e. without a value) is
        skipped instead of raising IndexError.
    """
    for opt in ["-c", "--config"]:
        # Long options attach their value with '='; short options directly.
        opt_prefix = opt + "=" if opt.startswith("--") else opt
        for idx, argv_opt in enumerate(argv):
            if argv_opt == opt:
                if idx + 1 < len(argv):
                    return Path(argv[idx + 1])
                # Dangling option with no value: nothing to return here, and
                # it must not fall through to the prefix check below.
                continue

            if argv_opt.startswith(opt_prefix):
                return Path(argv_opt[len(opt_prefix) :])

    return None


def _main_runner_factory(runners: Iterable[ClackRunner]) -> ClackRunner:
    """Build one runner that dispatches to the first runner matching cfg.

    The returned callable walks ``runners`` in order and calls the first one
    whose 'cfg' type hint ``cfg`` is an instance of. It returns -1 when no
    runner matches.
    """

    def run(cfg: Any) -> int:
        for candidate in runners:
            if isinstance(cfg, _get_run_cfg(candidate)):
                return candidate(cfg)
        return -1

    return run


def _get_run_cfg(run: ClackRunner) -> Type[ClackConfig]:
    """Return the type that ``run`` declares for its 'cfg' parameter.

    Raises:
        AssertionError: If ``run`` has no type hint named 'cfg'.
    """
    run_hints = get_type_hints(run)
    try:
        cfg: Type[ClackConfig] = run_hints["cfg"]
    except KeyError as e:
        raise AssertionError(
            "Logic Error! Every runner function should have a 'cfg' kwarg!"
        ) from e
    return cfg


def _config_type_from_command(
    all_config_types: Iterable[Type[ClackConfig]],
    choosen_command: str,
) -> Type[ClackConfig]:
    """Return the first config type whose single command is the chosen one.

    Raises:
        AssertionError: If no config type matches ``choosen_command``.
    """
    # Only kept so the failure message can show what each type maps to.
    config_type_to_command = {}
    for some_config_type in all_config_types:
        some_command = _get_single_command(some_config_type)
        config_type_to_command[some_config_type] = some_command
        if some_command == choosen_command:
            return some_config_type

    raise AssertionError(
        "Logic Error! None of the given Config types seem to match the"
        f" chosen sub-command. | {ASSERT_MAIN_RUNNERS_PRECOND} |"
        f" choosen_command={choosen_command}"
        f" config_type_to_command={config_type_to_command!r}"
    )


def _get_all_commands(config_type: Type[ClackConfig]) -> List[str]:
    """Return the values listed by the type hint of ``config_type.command``.

    Raises:
        AssertionError: If ``config_type`` has no 'command' type hint.
    """
    config_type_hints = get_type_hints(config_type)
    try:
        command_type = config_type_hints["command"]
    except KeyError as e:
        raise AssertionError(
            "Logic Error! When using sub-commands in your CLI interface with"
            " clack, ALL Config classes MUST have a 'command' attribute!"
        ) from e

    result = cast(List[str], literal_to_list(command_type))
    return result


def _get_single_command(config_type: Type[ClackConfig]) -> str:
    """Return the one and only command associated with ``config_type``."""
    all_commands = _get_all_commands(config_type)
    assert len(all_commands) == 1, (
        ASSERT_MAIN_RUNNERS_PRECOND + f" | {all_commands!r}"
    )
    return all_commands[0]


def _get_all_config_types(
    runners: Iterable[ClackRunner],
) -> List[Type[ClackConfig]]:
    """Return the 'cfg' type of every runner, in the order given."""
    return [_get_run_cfg(runner) for runner in runners]