Source code for class_resolver.api

# -*- coding: utf-8 -*-

"""Resolve classes."""

from typing import Any, Collection, Generic, Iterable, Mapping, Optional, Type, TypeVar, Union

__all__ = [
    'Hint',
    'Resolver',
    'get_subclasses',
    'get_cls',
    'normalize_string',
]

X = TypeVar('X')
Hint = Union[None, str, X]


class Resolver(Generic[X]):
    """Resolve from a list of classes."""

    def __init__(
        self,
        classes: Collection[Type[X]],
        *,
        base: Type[X],
        default: Optional[Type[X]] = None,
        suffix: Optional[str] = None,
        synonyms: Optional[Mapping[str, Type[X]]] = None,
    ) -> None:
        """Initialize the resolver.

        :param classes: A list of classes
        :param base: The base class
        :param default: The default class
        :param suffix: The optional shared suffix of all classes. If None, use the base class' name for it. To disable
            this behaviour, explicitly provide `suffix=""`.
        :param synonyms: The optional synonym dictionary
        """
        self.base = base
        self.default = default
        if suffix is None:
            suffix = normalize_string(base.__name__)
        self.suffix = suffix
        self.synonyms = synonyms
        self.lookup_dict = {
            self.normalize_cls(cls): cls
            for cls in classes
        }

    @classmethod
    def from_subclasses(cls, base: Type[X], *, skip: Optional[Collection[Type[X]]] = None, **kwargs) -> 'Resolver':
        """Make a resolver from the subclasses of a given class.

        :param base: The base class whose subclasses will be indexed
        :param skip: Any subclasses to skip (usually good to hardcode intermediate base classes)
        :param kwargs: remaining keyword arguments to pass to :func:`Resolver.__init__`
        :return: A resolver instance
        """
        skip = set(skip) if skip else set()
        return Resolver(
            {
                subcls
                for subcls in get_subclasses(base)
                if subcls not in skip
            },
            base=base,
            **kwargs,
        )

    def normalize_inst(self, x: X) -> str:
        """Normalize the class name of the instance."""
        return self.normalize_cls(x.__class__)

    def normalize_cls(self, cls: Type[X]) -> str:
        """Normalize the class name."""
        return self.normalize(cls.__name__)

    def normalize(self, s: str) -> str:
        """Normalize the string with this resolve's suffix."""
        return normalize_string(s, suffix=self.suffix)

    def lookup(self, query: Hint[Type[X]]) -> Type[X]:
        """Lookup a class."""
        return get_cls(
            query,
            base=self.base,
            lookup_dict=self.lookup_dict,
            lookup_dict_synonyms=self.synonyms,
            default=self.default,
            suffix=self.suffix,
        )

    def make(self, query: Hint[Union[X, Type[X]]], pos_kwargs: Optional[Mapping[str, Any]] = None, **kwargs) -> X:
        """Instantiate a class with optional kwargs."""
        if query is None or isinstance(query, (str, type)):
            cls: Type[X] = self.lookup(query)
            return cls(**(pos_kwargs or {}), **kwargs)  # type: ignore

        # An instance was passed, and it will go through without modification.
        return query

    def get_option(self, *flags: str, default: Optional[str] = None, **kwargs):
        """Get a click option for this resolver."""
        if default is None:
            if self.default is None:
                raise ValueError
            default = self.normalize_cls(self.default)

        import click

        return click.option(
            *flags,
            type=click.Choice(list(self.lookup_dict)),
            default=default,
            show_default=True,
            callback=_make_callback(self.lookup),
            **kwargs,
        )


def _not_hint(x: Any) -> bool:
    return x is not None and not isinstance(x, (str, type))


def get_subclasses(cls: Type[X]) -> Iterable[Type[X]]:
    """Get all subclasses.

    :param cls: The ancestor class
    :yields: Descendant classes of the ancestor class
    """
    for subclass in cls.__subclasses__():
        yield from get_subclasses(subclass)
        yield subclass


def get_cls(
    query: Union[None, str, Type[X]],
    base: Type[X],
    lookup_dict: Mapping[str, Type[X]],
    lookup_dict_synonyms: Optional[Mapping[str, Type[X]]] = None,
    default: Optional[Type[X]] = None,
    suffix: Optional[str] = None,
) -> Type[X]:
    """Get a class by string, default, or implementation."""
    if query is None:
        if default is None:
            raise ValueError(f'No default {base.__name__} set')
        return default
    elif not isinstance(query, (str, type)):
        raise TypeError(f'Invalid {base.__name__} type: {type(query)} - {query}')
    elif isinstance(query, str):
        key = normalize_string(query, suffix=suffix)
        if key in lookup_dict:
            return lookup_dict[key]
        if lookup_dict_synonyms is not None and key in lookup_dict_synonyms:
            return lookup_dict_synonyms[key]
        valid_choices = sorted(set(lookup_dict.keys()).union(lookup_dict_synonyms or []))
        raise ValueError(f'Invalid {base.__name__} name: {query}. Valid choices are: {valid_choices}')
    elif issubclass(query, base):
        return query
    raise TypeError(f'Not subclass of {base.__name__}: {query}')


[docs]def normalize_string(s: str, *, suffix: Optional[str] = None) -> str: """Normalize a string for lookup.""" s = s.lower().replace('-', '').replace('_', '').replace(' ', '') if suffix is not None and s.endswith(suffix.lower()): return s[:-len(suffix)] return s
def _make_callback(f): def _callback(_, __, value): return f(value) return _callback