Source code for

# -*- coding: utf-8 -*-

"""Training loops for KGE models using multi-modal information.

======  ==========================================
Name    Reference
======  ==========================================
lcwa    :class:`pykeen.training.LCWATrainingLoop`
slcwa   :class:`pykeen.training.SLCWATrainingLoop`
======  ==========================================

.. note:: This table can be re-generated with ``pykeen ls trainers -f rst``
"""

from typing import Mapping, Set, Type, Union

from .lcwa import LCWATrainingLoop  # noqa: F401
from .slcwa import SLCWATrainingLoop  # noqa: F401
from .training_loop import NonFiniteLossError, TrainingLoop  # noqa: F401
from ..utils import get_cls, normalize_string

__all__ = [

_TRAINING_LOOPS: Set[Type[TrainingLoop]] = {

#: A mapping of training loops' names to their implementations.
#: Each key is the class name passed through ``normalize_string`` with
#: ``_TRAINING_LOOP_SUFFIX`` as the ``suffix`` argument; each value is the class itself.
training_loops: Mapping[str, Type[TrainingLoop]] = {
    normalize_string(cls.__name__, suffix=_TRAINING_LOOP_SUFFIX): cls
    for cls in _TRAINING_LOOPS
}

def get_training_loop_cls(query: Union[None, str, Type[TrainingLoop]]) -> Type[TrainingLoop]:
    """Look up a training loop class by name (case/punctuation insensitive) in :data:`training_loops`.

    The lookup is delegated to ``get_cls`` with :class:`TrainingLoop` as the base class,
    ``training_loops`` as the lookup dictionary and ``SLCWATrainingLoop`` as the default.

    :param query: The name of the training loop (case insensitive, punctuation insensitive).
        Per the annotation this may also be ``None`` or a training loop class; how those are
        resolved is decided by ``get_cls`` (presumably ``None`` yields the default -- confirm
        against ``get_cls``).
    :return: The training loop class
    """
    return get_cls(
        query,
        base=TrainingLoop,
        lookup_dict=training_loops,
        default=SLCWATrainingLoop,
        # Same suffix that was used to build the keys of ``training_loops``.
        suffix=_TRAINING_LOOP_SUFFIX,
    )