# Source code for pykeen.models.unimodal.rgcn

# -*- coding: utf-8 -*-

"""Implementation of the R-GCN model."""

from typing import Any, Mapping, Optional

import torch
from class_resolver import Hint
from torch import nn

from ..nbase import ERModel, EmbeddingSpecificationHint
from ...nn.emb import EmbeddingSpecification, RGCNRepresentations
from ...nn.message_passing import Decomposition
from ...nn.modules import Interaction, interaction_resolver
from ...nn.weighting import EdgeWeighting
from ...triples import CoreTriplesFactory
from ...typing import Initializer, RelationRepresentation

# Public API of this module: only the R-GCN model class is exported.
__all__ = ["RGCN"]


class RGCN(
    ERModel[torch.FloatTensor, RelationRepresentation, torch.FloatTensor],
):
    """An implementation of R-GCN from [schlichtkrull2018]_.

    This model uses graph convolutions with relation-specific weights.

    .. seealso::

       - `Pytorch Geometric's implementation of R-GCN
         <https://github.com/rusty1s/pytorch_geometric/blob/1.3.2/examples/rgcn.py>`_
       - `DGL's implementation of R-GCN
         <https://github.com/dmlc/dgl/tree/v0.4.0/examples/pytorch/rgcn>`_
    ---
    citation:
        author: Schlichtkrull
        year: 2018
        link: https://arxiv.org/pdf/1703.06103
    """

    #: The default strategy for optimizing the model's hyper-parameters.
    #: Keys are expected to name keyword arguments of :meth:`__init__`.
    hpo_default = dict(
        embedding_dim=dict(type=int, low=32, high=512, q=32),
        num_layers=dict(type=int, low=1, high=5, q=1),
        use_bias=dict(type="bool"),
        use_batch_norm=dict(type="bool"),
        # Must be ``activation`` (the actual ``__init__`` parameter); a key such as
        # ``activation_cls`` matches no parameter, would land in ``**kwargs`` and be
        # forwarded to ``ERModel.__init__``.
        activation=dict(type="categorical", choices=[nn.ReLU, nn.LeakyReLU]),
        interaction=dict(type="categorical", choices=["distmult", "complex", "ermlp"]),
        edge_dropout=dict(type=float, low=0.0, high=0.9),
        self_loop_dropout=dict(type=float, low=0.0, high=0.9),
        edge_weighting=dict(
            type="categorical",
            choices=["inverse_in_degree", "inverse_out_degree", "symmetric"],
        ),
        decomposition=dict(type="categorical", choices=["bases", "blocks"]),
        # TODO: Decomposition kwargs
        # num_bases=dict(type=int, low=2, high=100, q=1),
        # num_blocks=dict(type=int, low=2, high=20, q=1),
    )

    def __init__(
        self,
        *,
        triples_factory: CoreTriplesFactory,
        embedding_dim: int = 500,
        num_layers: int = 2,
        # https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/encoders/affine_transform.py#L24-L28
        base_entity_initializer: Hint[Initializer] = nn.init.xavier_uniform_,
        base_entity_initializer_kwargs: Optional[Mapping[str, Any]] = None,
        relation_initializer: Hint[Initializer] = nn.init.xavier_uniform_,
        relation_initializer_kwargs: Optional[Mapping[str, Any]] = None,
        relation_representations: EmbeddingSpecificationHint = None,
        interaction: Hint[Interaction[torch.FloatTensor, RelationRepresentation, torch.FloatTensor]],
        interaction_kwargs: Optional[Mapping[str, Any]] = None,
        use_bias: bool = True,
        use_batch_norm: bool = False,
        activation: Hint[nn.Module] = None,
        activation_kwargs: Optional[Mapping[str, Any]] = None,
        edge_dropout: float = 0.4,
        self_loop_dropout: float = 0.2,
        edge_weighting: Hint[EdgeWeighting] = None,
        decomposition: Hint[Decomposition] = None,
        decomposition_kwargs: Optional[Mapping[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Initialize the R-GCN model.

        :param triples_factory:
            The triples factory. It is used to build the R-GCN entity representations
            and is also passed on to :class:`ERModel`.
        :param embedding_dim:
            The dimension of the base entity embeddings given to :class:`RGCNRepresentations`.
        :param num_layers:
            The number of R-GCN layers, forwarded to :class:`RGCNRepresentations`.
        :param base_entity_initializer:
            The initializer of the base entity embeddings.
        :param base_entity_initializer_kwargs:
            Additional keyword arguments for ``base_entity_initializer``.
        :param relation_initializer:
            The initializer of the default relation embeddings. Only used when
            ``relation_representations`` is None.
        :param relation_initializer_kwargs:
            Additional keyword arguments for ``relation_initializer``. Only used when
            ``relation_representations`` is None.
        :param relation_representations:
            The relation representations. If None, an embedding specification with the
            same shape as the entity representations is created.
        :param interaction:
            The interaction function; anything ``interaction_resolver.make`` accepts
            (the HPO defaults use names such as ``"distmult"``). This argument is required.
        :param interaction_kwargs:
            Passed as ``pos_kwargs`` to ``interaction_resolver.make``.
        :param use_bias:
            Forwarded to :class:`RGCNRepresentations`.
        :param use_batch_norm:
            Forwarded to :class:`RGCNRepresentations`.
        :param activation:
            The activation hint, forwarded to :class:`RGCNRepresentations`.
        :param activation_kwargs:
            Keyword arguments for ``activation``, forwarded to :class:`RGCNRepresentations`.
        :param edge_dropout:
            The edge dropout rate, forwarded to :class:`RGCNRepresentations`.
        :param self_loop_dropout:
            The self-loop dropout rate, forwarded to :class:`RGCNRepresentations`.
        :param edge_weighting:
            The edge weighting hint, forwarded to :class:`RGCNRepresentations`.
        :param decomposition:
            The weight decomposition hint, forwarded to :class:`RGCNRepresentations`.
        :param decomposition_kwargs:
            Keyword arguments for ``decomposition``, forwarded to :class:`RGCNRepresentations`.
        :param kwargs:
            Remaining keyword arguments, passed to :meth:`ERModel.__init__`.
        """
        # create enriched entity representations on top of trainable base embeddings
        entity_representations = RGCNRepresentations(
            triples_factory=triples_factory,
            embedding_specification=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=base_entity_initializer,
                initializer_kwargs=base_entity_initializer_kwargs,
            ),
            num_layers=num_layers,
            use_bias=use_bias,
            use_batch_norm=use_batch_norm,
            activation=activation,
            activation_kwargs=activation_kwargs,
            edge_dropout=edge_dropout,
            self_loop_dropout=self_loop_dropout,
            edge_weighting=edge_weighting,
            decomposition=decomposition,
            decomposition_kwargs=decomposition_kwargs,
        )

        # resolve the interaction hint (e.g. a name from hpo_default) into an instance
        interaction = interaction_resolver.make(query=interaction, pos_kwargs=interaction_kwargs)

        # default relation representation: same shape as the entity representations
        if relation_representations is None:
            relation_representations = EmbeddingSpecification(
                shape=entity_representations.shape,
                initializer=relation_initializer,
                initializer_kwargs=relation_initializer_kwargs,
            )

        super().__init__(
            entity_representations=entity_representations,
            relation_representations=relation_representations,
            triples_factory=triples_factory,
            interaction=interaction,
            **kwargs,
        )