Source code for clack._parser

"""Defines the clack.Parser() helper function."""

from __future__ import annotations

import argparse
from importlib.metadata import (
    PackageNotFoundError,
    distribution as get_dist_from_name,
    distributions as get_all_dists,
    version as get_version,
)
import inspect
import os
from pathlib import Path
import re
import sys
from types import ModuleType
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    cast,
)

from logrus import Log, LogFormat, Logger, LogLevel, get_default_logfile
from typist import literal_to_list

from . import _dynvars as dyn
from ._config_file import YAMLConfigFile


# Sentinel default that the patched add_argument() (installed by
# _patch_add_argument_method) assigns to arguments that have neither an
# explicit ``default`` nor a value in the config-file defaults. Presumably it
# lets later code tell "never set" apart from a genuine None default --
# verify against its consumers.
ARGPARSE_ARGUMENT_DEFAULT = object()

# Module-level logger; used for the warnings emitted while resolving
# distribution names.
logger = Logger(__name__)


def Parser(*args: Any, **kwargs: Any) -> argparse.ArgumentParser:
    """Wrapper for argparse.ArgumentParser.

    Builds an ArgumentParser that is pre-populated with the options shared by
    every clack application: ``-c/--config``, ``-L/--log``, ``-v/--verbose``
    and, when the calling module belongs to an installed distribution,
    ``--version``. The parser's add_argument() method is monkey-patched (see
    monkey_patch_parser) so arguments default to config-file values.

    Args:
        *args: Positional arguments forwarded to argparse.ArgumentParser.
        **kwargs: Keyword arguments forwarded to argparse.ArgumentParser.
            ``description`` defaults to the calling module's docstring and
            ``formatter_class`` defaults to _HelpFormatter.

    Returns:
        The configured argument parser.
    """
    app_name = dyn.get_app_name()

    # Record 0 is this function's own frame and record 1 belongs to the code
    # that called Parser(). The outer records are only needed for their
    # filenames (to report which executable was run in --version).
    # context=0 skips reading source lines, which nothing here uses.
    stack = inspect.stack(context=0)
    caller_frame = stack[1].frame
    caller_doc = caller_frame.f_globals.get("__doc__")
    caller_mod = inspect.getmodule(caller_frame)
    outer_filenames = [record.filename for record in stack[2:]]
    # Release frame references promptly so they don't sit in reference
    # cycles (see the inspect docs on "The interpreter stack").
    del stack, caller_frame

    # ArgumentParser's positional parameters start with (prog, usage,
    # description, epilog, parents, formatter_class, ...); only inject a
    # keyword when it was not already supplied positionally, otherwise
    # ArgumentParser() raises "got multiple values".
    if len(args) < 3 and kwargs.get("description") is None:
        kwargs["description"] = caller_doc
    if len(args) < 6 and kwargs.get("formatter_class") is None:
        kwargs["formatter_class"] = _HelpFormatter

    valid_log_levels = sorted(cast(List[str], literal_to_list(LogLevel)))
    valid_log_formats = sorted(cast(List[str], literal_to_list(LogFormat)))

    parser = argparse.ArgumentParser(*args, **kwargs)
    monkey_patch_parser(parser)
    parser.add_argument(
        "-c",
        "--config",
        dest="config_file",
        type=YAMLConfigFile,
        help=(
            "Absolute or relative path to a YAML file that contains this"
            " application's configuration."
        ),
    )
    parser.add_argument(
        "-L",
        "--log",
        metavar="FILE[:LEVEL][@FORMAT]",
        dest="logs",
        action="append",
        nargs="?",
        const="+",
        type=_log_type_factory(app_name),
        help=(
            "This option can be used to enable a new logging handler. FILE"
            " should be either a path to a logfile or one of the following"
            " special file types: [1] 'stderr' to log to standard error"
            " (enabled by default), [2] 'stdout' to log to standard out, [3]"
            " 'null' to disable all console (e.g. stderr) handlers, or [4]"
            " '+[NAME]' to choose a default logfile path (where NAME is an"
            " optional basename for the logfile). LEVEL can be any valid log"
            f" level (i.e. one of {valid_log_levels}) and FORMAT can be any"
            f" valid log format (i.e. one of {valid_log_formats}). NOTE: This"
            " option can be specified multiple times and has a default"
            " argument of %(const)r."
        ),
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        help=(
            "How verbose should the output be? This option can be specified"
            " multiple times (e.g. -v, -vv, -vvv, ...)."
        ),
    )

    version = _build_version_string(caller_mod, outer_filenames)
    if version is not None:
        parser.add_argument("--version", action="version", version=version)
    return parser


def _build_version_string(
    caller_mod: ModuleType | None, outer_filenames: Sequence[str]
) -> str | None:
    """Builds the text printed by --version, or None to omit that option.

    Args:
        caller_mod: The module that called Parser().
        outer_filenames: Filenames of the stack frames outside the caller,
            innermost first; the first executable one is reported as the
            program that was run.

    Returns:
        A multi-line version report, or None when the caller's distribution
        (or its version) cannot be determined.
    """
    caller_dist_name = _get_dist_name_from_mod(caller_mod)
    caller_file = getattr(caller_mod, "__file__", None)
    if not (caller_dist_name and caller_file):
        return None
    try:
        caller_version = get_version(caller_dist_name)
    except PackageNotFoundError:
        return None

    # Locate the caller by its import package, not its distribution name:
    # the two can differ (e.g. "my-app" vs "my_app"), and only the package
    # name accounts for subpackage directories (e.g. "my_app.cli").
    caller_pkg = getattr(caller_mod, "__package__", None) or caller_dist_name
    lines = [
        f"{caller_dist_name} {caller_version}",
        f" from {_get_package_location(caller_file, caller_pkg)}",
    ]
    exe_fname = next(
        (fname for fname in outer_filenames if os.access(fname, os.X_OK)),
        None,
    )
    if exe_fname is not None:
        lines.append(f" by {_shorten_homedir(exe_fname)}")
    pyversion = ".".join(str(x) for x in sys.version_info[:3])
    lines.append(f" using Python {pyversion}")

    # clack's own details are best-effort: failing to resolve them must not
    # cost the application its --version option.
    clack_pkg = __package__ or "clack"
    clack_dist_name = _get_dist_name_from_mod_and_pkg_name(
        clack_pkg, clack_pkg
    )
    if clack_dist_name is not None:
        try:
            clack_version = get_version(clack_dist_name)
        except PackageNotFoundError:
            pass
        else:
            lines.append(f"{clack_dist_name} {clack_version}")
            clack_location = _get_package_location(__file__, clack_pkg)
            lines.append(f" from {clack_location}")
    return "\n".join(lines)
def _get_dist_name_from_mod(from_mod: ModuleType | None) -> str | None:
    """
    Retrieves the PyPI distribution name associated with a given module.

    Args:
        from_mod: The module object for which to find the distribution.

    Returns:
        The PyPI distribution name, or None if not found.
    """
    if from_mod is None:
        logger.warn(
            "Aborting distribution name search since from module is None."
        )
        return None

    # __package__ can be None (e.g. a script executed by path, which becomes
    # __main__) or "" (a top-level module); fall back to the module's
    # top-level name instead of failing an assertion.
    pkg_name = from_mod.__package__ or from_mod.__name__.partition(".")[0]
    return _get_dist_name_from_mod_and_pkg_name(from_mod.__name__, pkg_name)


def _get_dist_name_from_mod_and_pkg_name(
    mod_name: str, pkg_name: str
) -> str | None:
    """Finds a distribution by module name, falling back to the package name.

    Args:
        mod_name: Dotted module name, tried directly as a distribution name.
        pkg_name: Dotted package name used for the top_level.txt fallback.

    Returns:
        The distribution's ``Name`` metadata, or None if nothing matches.
    """
    try:
        # Attempt to get the dist metadata directly using the module name.
        distribution = get_dist_from_name(mod_name)
    except PackageNotFoundError:
        return _get_dist_name_from_pkg_name(pkg_name)
    return distribution.metadata["Name"]


def _get_dist_name_from_pkg_name(pkg_name: str) -> str | None:
    """Finds the distribution whose top_level.txt lists pkg_name's root.

    Args:
        pkg_name: Dotted package name (e.g. ``"foo.bar"``).

    Returns:
        The first matching distribution's ``Name`` metadata, or None (after
        logging a warning) if no installed distribution matches.
    """
    # Fallback logic for packages whose import name does not directly match
    # the distribution name. top_level.txt lists only top-level import names,
    # so compare against the root component alone; matching every dotted
    # component would let unrelated distributions that ship e.g. a top-level
    # "tests" or "utils" package claim the module.
    top_level_name = pkg_name.partition(".")[0]
    for dist in get_all_dists():
        module_names = dist.read_text("top_level.txt")
        if (
            module_names is not None
            and top_level_name in module_names.splitlines()
        ):
            return dist.metadata["Name"]
    logger.warn(
        "Unable to match package name to any known distribution.",
        pkg_name=pkg_name,
    )
    return None


def monkey_patch_parser(parser: argparse.ArgumentParser) -> None:
    """Tweaks ArgumentParser a bit so it works better with clack."""
    _patch_add_argument_method(parser)


def _patch_add_argument_method(parser: argparse.ArgumentParser) -> None:
    """Makes parser.add_argument() default to config-file values.

    Arguments registered without an explicit ``default`` receive the value
    stored under their destination name in dyn.get_config_defaults(), or
    ARGPARSE_ARGUMENT_DEFAULT when there is no such value.
    """
    # Bind the current method (rather than calling the base-class function)
    # so subclasses that override add_argument() keep their behavior.
    original_add_argument = parser.add_argument

    def add_argument(*args: Any, **kwargs: Any) -> argparse.Action:
        if "default" not in kwargs:
            field_name = _get_field_name(args, kwargs)
            config_defaults = dyn.get_config_defaults()
            kwargs["default"] = config_defaults.get(
                field_name, ARGPARSE_ARGUMENT_DEFAULT
            )
        # Return the created Action, as argparse's add_argument() does, so
        # callers can keep a handle on it.
        return original_add_argument(*args, **kwargs)

    parser.add_argument = add_argument  # type: ignore[assignment]


def _get_field_name(args: Sequence[str], kwargs: Mapping[str, Any]) -> str:
    """Returns the attribute name argparse will store an argument under.

    Mirrors argparse's rule: an explicit ``dest`` wins, a positional argument
    uses its own name, and an optional argument uses its first long option
    (or its first option if it has no long one) with leading dashes stripped
    and inner dashes turned into underscores.
    """
    if dest := kwargs.get("dest", None):
        assert isinstance(dest, str)
        return dest
    if not args[0].startswith("-"):
        return args[0]
    option = next((opt for opt in args if opt.startswith("--")), args[0])
    return option.lstrip("-").replace("-", "_")


def _log_type_factory(app_name: str) -> Callable[[str], Log]:
    """Creates the argparse ``type=`` converter used by the -L/--log option.

    Args:
        app_name: Logfile basename used when FILE is a bare '+'.

    Returns:
        A function that turns a 'FILE[:LEVEL][@FORMAT]' string into a Log and
        raises argparse.ArgumentTypeError for malformed specifications.
    """
    # This regex will match arguments of the form 'FILE[:LEVEL][@FORMAT]'.
    pttrn = (
        r"^(?P<file>[^:@]+)(?::(?P<level>[^:@]+))?(?:@(?P<format>[^:@]+))?"
    )
    # Compiled once per factory call instead of once per -L occurrence.
    log_spec_re = re.compile(pttrn)

    def log_type(arg: str) -> Log:
        # fullmatch() rejects trailing junk (e.g. 'foo:debug:extra' or
        # 'foo:@json') that a prefix match() would silently ignore.
        match = log_spec_re.fullmatch(arg)
        if not match:
            raise argparse.ArgumentTypeError(
                f"Bad log specification ({arg!r}). Must match the following"
                f" regular expression: {pttrn!r}"
            )

        file = match.group("file")
        # If FILE is of the form '+[NAME]'...
        if file.startswith("+"):
            # Then we use a default logfile location.
            logfile_stem = file[1:]
            if not logfile_stem:
                logfile_stem = app_name
            file = str(get_default_logfile(logfile_stem))

        # If `--log null` is specified on the command-line...
        if file == "null":
            # HACK: The intention here is to disable logging to the console
            # (i.e. 'stderr' or 'stdout'). The actual effect is that only
            # CRITICAL logging messages will get logged to stderr. This
            # approaches a real solution since CRITICAL is used so
            # infrequently in practice, but is not technically correct.
            return Log(file="stderr", format="nocolor", level="CRITICAL")

        format_ = cast(Optional[LogFormat], match.group("format"))
        # If format is unset and this is a console logger...
        if format_ is None and file in ["stdout", "stderr"]:
            format_ = "color"
        # Else if format is unset and this is a file logger...
        elif format_ is None:
            format_ = "json"

        level = cast(Optional[LogLevel], match.group("level"))
        if level is not None:
            level = cast(LogLevel, level.upper())

        return Log(file=file, format=format_, level=level)

    return log_type


def _get_package_location(file_path: str, package: str) -> str:
    """Returns the directory containing ``package``, with $HOME shown as ~.

    Args:
        file_path: Path of a module file that lives directly in ``package``.
        package: Dotted package name (e.g. ``"clack"`` or ``"foo.bar"``).

    Returns:
        The parent of file_path's directory with the package's directories
        removed. If those trailing directories do not spell out ``package``,
        file_path's own directory is returned unchanged.
    """
    location = Path(file_path).parent
    pkg_parts = tuple(part for part in package.split(".") if part)
    # Only strip the package directories when they are the *trailing* path
    # components; a textual search could hit an unrelated directory that
    # merely shares the package's name.
    if pkg_parts and location.parts[-len(pkg_parts):] == pkg_parts:
        for _ in pkg_parts:
            location = location.parent
    return _shorten_homedir(str(location))


def _shorten_homedir(path: str) -> str:
    """Abbreviates a leading home-directory prefix of ``path`` to ``~``."""
    home = Path.home()
    # With HOME=/ every absolute path would be "under home"; leave it alone.
    if home == Path(home.anchor):
        return path
    try:
        relative = Path(path).relative_to(home)
    except ValueError:
        # Not inside the home directory (e.g. /home/alice2 vs /home/alice).
        return path
    return str(Path("~", relative))


class _HelpFormatter(argparse.RawDescriptionHelpFormatter):
    """
    Custom argparse.HelpFormatter that uses raw descriptions and sorts
    optional arguments alphabetically.
    """

    def add_arguments(self, actions: Iterable[argparse.Action]) -> None:
        actions = sorted(actions, key=_argparse_action_key)
        super().add_arguments(actions)


def _argparse_action_key(action: argparse.Action) -> str:
    """Sort key for help output: last option string sans dashes, else dest."""
    opts = action.option_strings
    if opts:
        return opts[-1].lstrip("-")
    else:
        return action.dest