# Source code for pykeen.models.unimodal.rgcn

"""Implementation of the R-GCN model."""

from collections.abc import Mapping
from typing import Any

from class_resolver import Hint, HintOrType
from torch import nn

from ..nbase import ERModel
from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn.message_passing import Decomposition, RGCNRepresentation
from ...nn.modules import Interaction
from ...nn.representation import Representation
from ...nn.weighting import EdgeWeighting
from ...regularizers import Regularizer
from ...triples import CoreTriplesFactory
from ...typing import FloatTensor, Initializer, RelationRepresentation

# Public API of this module: only the R-GCN model class is exported.
__all__ = [
    "RGCN",
]


class RGCN(ERModel[FloatTensor, RelationRepresentation, FloatTensor]):
    r"""An implementation of R-GCN from [schlichtkrull2018]_.

    The Relational Graph Convolutional Network (R-GCN) comprises three parts:

    1. A GCN-based entity encoder that computes enriched representations for entities, cf.
       :class:`pykeen.nn.message_passing.RGCNRepresentation`. The representation for entity $i$ at level
       $l \in (1,\dots,L)$ is denoted as $\textbf{e}_i^l$.
       The GCN is modified to use different weights depending on the type of the relation.
    2. Relation representations $\textbf{R}_{r} \in \mathbb{R}^{d \times d}$, which are diagonal matrices
       that are learned independently from the GCN-based encoder.
    3. An arbitrary interaction model which computes the plausibility of facts given the enriched
       representations, cf. :class:`pykeen.nn.modules.Interaction`.

    Scores for each triple $(h,r,t) \in \mathcal{K}$ are calculated by using the representations in the final
    level of the GCN-based encoder $\textbf{e}_h^L$ and $\textbf{e}_t^L$ along with relation representation
    $\textbf{R}_{r}$. While the original implementation of R-GCN used the DistMult model and we use it as a
    default, this implementation allows the specification of an arbitrary interaction model.

    .. math::

        f(h,r,t) = \textbf{e}_h^L \textbf{R}_{r} \textbf{e}_t^L

    .. seealso::

        - `PyTorch Geometric's implementation of R-GCN
          <https://github.com/rusty1s/pytorch_geometric/blob/1.3.2/examples/rgcn.py>`_
        - `DGL's implementation of R-GCN
          <https://github.com/dmlc/dgl/tree/v0.4.0/examples/pytorch/rgcn>`_
    ---
    name: R-GCN
    citation:
        author: Schlichtkrull
        year: 2018
        link: https://arxiv.org/pdf/1703.06103
        github: https://github.com/MichSchli/RelationPrediction
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = {
        "embedding_dim": DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
        "num_layers": {"type": int, "low": 1, "high": 5, "q": 1},
        "use_bias": {"type": "bool"},
        "activation": {"type": "categorical", "choices": [nn.ReLU, nn.LeakyReLU]},
        "interaction": {"type": "categorical", "choices": ["distmult", "complex", "ermlp"]},
        "edge_dropout": DEFAULT_DROPOUT_HPO_RANGE,
        "self_loop_dropout": DEFAULT_DROPOUT_HPO_RANGE,
        "edge_weighting": {
            "type": "categorical",
            "choices": ["inverse_in_degree", "inverse_out_degree", "symmetric"],
        },
        "decomposition": {"type": "categorical", "choices": ["bases", "block"]},
        # TODO: Decomposition kwargs
        # num_bases=dict(type=int, low=2, high=100, q=1),
        # num_blocks=dict(type=int, low=2, high=20, q=1),
    }

    def __init__(
        self,
        *,
        triples_factory: CoreTriplesFactory,
        embedding_dim: int = 500,
        num_layers: int = 2,
        # https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/encoders/affine_transform.py#L24-L28
        base_entity_initializer: Hint[Initializer] = nn.init.xavier_uniform_,
        base_entity_initializer_kwargs: Mapping[str, Any] | None = None,
        relation_representations: HintOrType[Representation] = None,
        relation_initializer: Hint[Initializer] = nn.init.xavier_uniform_,
        relation_initializer_kwargs: Mapping[str, Any] | None = None,
        interaction: HintOrType[Interaction[FloatTensor, RelationRepresentation, FloatTensor]] = "DistMult",
        interaction_kwargs: Mapping[str, Any] | None = None,
        use_bias: bool = True,
        activation: Hint[nn.Module] = None,
        activation_kwargs: Mapping[str, Any] | None = None,
        edge_dropout: float = 0.4,
        self_loop_dropout: float = 0.2,
        edge_weighting: Hint[EdgeWeighting] = None,
        decomposition: Hint[Decomposition] = None,
        decomposition_kwargs: Mapping[str, Any] | None = None,
        regularizer: Hint[Regularizer] = None,
        regularizer_kwargs: Mapping[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """
        Initialize the model.

        :param triples_factory:
            the (training) triples factory
        :param embedding_dim:
            the embedding dimension; used as shape for both the base entity and the relation representations
        :param num_layers: >0
            the number of layers
        :param base_entity_initializer:
            the entity base representation initializer
        :param base_entity_initializer_kwargs:
            the entity base representation initializer's keyword-based parameters
        :param relation_representations:
            the relation representations, or a hint thereof
        :param relation_initializer:
            the relation representation initializer
        :param relation_initializer_kwargs:
            the relation representation initializer's keyword-based parameters
        :param interaction:
            the interaction function, or a hint thereof
        :param interaction_kwargs:
            additional keyword-based parameters passed to the interaction function
        :param use_bias:
            whether to use a bias on the message passing layers
        :param activation:
            the activation function, or a hint thereof
        :param activation_kwargs:
            additional keyword-based parameters passed to the activation function
        :param edge_dropout:
            the edge dropout, except for self-loops
        :param self_loop_dropout:
            the self-loop dropout
        :param edge_weighting:
            the edge weighting
        :param decomposition:
            the convolution weight decomposition
        :param decomposition_kwargs:
            additional keyword-based parameters passed to the weight decomposition
        :param regularizer:
            the regularizer, or a hint thereof; it is forwarded to both the entity (R-GCN) representation
            and the relation representation
        :param regularizer_kwargs:
            additional keyword-based parameters passed to the regularizer
        :param kwargs:
            additional keyword-based parameters passed to :meth:`ERModel.__init__`
        """
        super().__init__(
            # the entity representation is the R-GCN encoder, which wraps a base representation
            entity_representations=RGCNRepresentation,
            entity_representations_kwargs={
                "triples_factory": triples_factory,
                # configuration of the base representation which the encoder enriches
                "entity_representations_kwargs": {
                    "shape": embedding_dim,
                    "initializer": base_entity_initializer,
                    "initializer_kwargs": base_entity_initializer_kwargs,
                },
                "num_layers": num_layers,
                "use_bias": use_bias,
                "activation": activation,
                "activation_kwargs": activation_kwargs,
                "edge_dropout": edge_dropout,
                "self_loop_dropout": self_loop_dropout,
                "edge_weighting": edge_weighting,
                "decomposition": decomposition,
                "decomposition_kwargs": decomposition_kwargs,
                # cf. https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/decoders/bilinear_diag.py#L64-L67  # noqa: E501
                "regularizer": regularizer,
                "regularizer_kwargs": regularizer_kwargs,
            },
            relation_representations=relation_representations,
            relation_representations_kwargs={
                "shape": embedding_dim,
                "initializer": relation_initializer,
                "initializer_kwargs": relation_initializer_kwargs,
                # cf. https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/decoders/bilinear_diag.py#L64-L67  # noqa: E501
                "regularizer": regularizer,
                "regularizer_kwargs": regularizer_kwargs,
            },
            triples_factory=triples_factory,
            interaction=interaction,
            interaction_kwargs=interaction_kwargs,
            **kwargs,
        )