Source code for pykeen.models

# -*- coding: utf-8 -*-

"""Implementations of various knowledge graph embedding models.

===================  ==========================================  ====================
Name                 Reference                                   Citation
===================  ==========================================  ====================
ComplEx              :class:`pykeen.models.ComplEx`              [trouillon2016]_
ComplExLiteral       :class:`pykeen.models.ComplExLiteral`       [agustinus2018]_
ConvE                :class:`pykeen.models.ConvE`                [dettmers2018]_
ConvKB               :class:`pykeen.models.ConvKB`               [nguyen2018]_
DistMult             :class:`pykeen.models.DistMult`             [yang2014]_
DistMultLiteral      :class:`pykeen.models.DistMultLiteral`      [agustinus2018]_
ERMLP                :class:`pykeen.models.ERMLP`                [dong2014]_
ERMLPE               :class:`pykeen.models.ERMLPE`               [sharifzadeh2019]_
HolE                 :class:`pykeen.models.HolE`                 [nickel2016]_
KG2E                 :class:`pykeen.models.KG2E`                 [he2015]_
NTN                  :class:`pykeen.models.NTN`                  [socher2013]_
ProjE                :class:`pykeen.models.ProjE`                [shi2017]_
RESCAL               :class:`pykeen.models.RESCAL`               [nickel2011]_
RGCN                 :class:`pykeen.models.RGCN`                 [schlichtkrull2018]_
RotatE               :class:`pykeen.models.RotatE`               [sun2019]_
SimplE               :class:`pykeen.models.SimplE`               [kazemi2018]_
StructuredEmbedding  :class:`pykeen.models.StructuredEmbedding`  [bordes2011]_
TransD               :class:`pykeen.models.TransD`               [ji2015]_
TransE               :class:`pykeen.models.TransE`               [bordes2013]_
TransH               :class:`pykeen.models.TransH`               [wang2014]_
TransR               :class:`pykeen.models.TransR`               [lin2015]_
TuckER               :class:`pykeen.models.TuckER`               [balazevic2019]_
UnstructuredModel    :class:`pykeen.models.UnstructuredModel`    [bordes2014]_
===================  ==========================================  ====================

.. note:: This table can be re-generated with ``pykeen ls models -f rst``
"""

from typing import Mapping, Set, Type, Union

from .base import EntityEmbeddingModel, EntityRelationEmbeddingModel, Model
from .multimodal import ComplExLiteral, DistMultLiteral, MultimodalModel
from .unimodal import (
    ComplEx,
    ConvE,
    ConvKB,
    DistMult,
    ERMLP,
    ERMLPE,
    HolE,
    KG2E,
    NTN,
    ProjE,
    RESCAL,
    RGCN,
    RotatE,
    SimplE,
    StructuredEmbedding,
    TransD,
    TransE,
    TransH,
    TransR,
    TuckER,
    UnstructuredModel,
)
from ..utils import get_cls, normalize_string

__all__ = [
    # Base classes: public so that users can subclass them and refer to them
    # in annotations (``get_model_cls`` is typed against ``Model``).
    'Model',
    'EntityEmbeddingModel',
    'EntityRelationEmbeddingModel',
    'MultimodalModel',
    # Concrete model implementations
    'ComplEx',
    'ComplExLiteral',
    'ConvE',
    'ConvKB',
    'DistMult',
    'DistMultLiteral',
    'ERMLP',
    'ERMLPE',
    'HolE',
    'KG2E',
    'NTN',
    'ProjE',
    'RESCAL',
    'RGCN',
    'RotatE',
    'SimplE',
    'StructuredEmbedding',
    'TransD',
    'TransE',
    'TransH',
    'TransR',
    'TuckER',
    'UnstructuredModel',
    # Lookup utilities
    'models',
    'get_model_cls',
]


def _recur(c):
    for sc in c.__subclasses__():
        yield sc
        yield from _recur(sc)


# Every transitive subclass of ``Model`` (deduplicated by the set), minus the
# intermediate base classes that are not themselves selectable models.
_MODELS: Set[Type[Model]] = set(_recur(Model)).difference((
    Model,
    MultimodalModel,
    EntityRelationEmbeddingModel,
    EntityEmbeddingModel,
))

#: A mapping of models' names to their implementations. Keys are the
#: class names passed through ``normalize_string``. Entries are inserted in
#: sorted key order: ``_MODELS`` is a set of classes (id-based hashes), so
#: iterating it directly would give a different order on every run.
models: Mapping[str, Type[Model]] = dict(sorted(
    ((normalize_string(cls.__name__), cls) for cls in _MODELS),
    # Sort on the key only; comparing two classes would raise TypeError.
    key=lambda pair: pair[0],
))


def get_model_cls(query: Union[str, Type[Model]]) -> Type[Model]:
    """Get the model class.

    :param query: Either the name of a model or a model class itself.
    :returns: The model class resolved by :func:`pykeen.utils.get_cls`.

    Resolution is delegated to ``get_cls`` with ``base=Model`` and the
    module-level ``models`` mapping as the lookup table. Presumably strings
    are normalized and looked up in ``models``, while classes are checked
    against ``Model``. The exact matching rules and the exception raised for
    an unknown name live in ``get_cls`` (TODO confirm there).
    """
    return get_cls(
        query,
        base=Model,
        lookup_dict=models,
    )