"""Various decompositions for R-GCN."""
import logging
import math
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional
import torch
from class_resolver import ClassResolver, Hint, HintOrType, OptionalKwargs
from class_resolver.contrib.torch import activation_resolver
from docdata import parse_docdata
from torch import nn
from .init import uniform_norm_p1_, xavier_normal_
from .representation import LowRankRepresentation, Representation
from .utils import ShapeError, adjacency_tensor_to_stacked_matrix, use_horizontal_stacking
from .weighting import EdgeWeighting, edge_weight_resolver
from ..triples import CoreTriplesFactory
from ..typing import FloatTensor, LongTensor
from ..utils import ExtraReprMixin, einsum
__all__ = [
"RGCNRepresentation",
"RGCNLayer",
"Decomposition",
"BasesDecomposition",
"BlockDecomposition",
"decomposition_resolver",
]
logger = logging.getLogger(__name__)
class Decomposition(nn.Module, ExtraReprMixin, ABC):
r"""Base module for relation-specific message passing.
A decomposition module offers a way to reduce the number of parameters compared to learning an
independent $d \times d$ transformation matrix for each relation. In R-GCN, the two proposed variants are treated as
hyper-parameters, and different decompositions perform better on different datasets.
The decomposition module itself does not compute the full matrix from the factors, but rather provides efficient
means to compute the product of the factorized matrix with the source nodes' latent features to construct the
messages. This is usually more efficient than constructing the full matrices.
For an intuition, consider a simple low-rank matrix factorization of rank $1$, where $W = w w^T$
for a $d$-dimensional vector $w$. Computing $Wv$ as $(w w^T) v$ produces an intermediate result of size
$d \times d$, whereas computing $w(w^Tv)$ instead requires only a scalar intermediate result.
The implementations use the efficient version based on adjacency tensor stacking from [thanapalasingam2021]_.
The adjacency tensor is reshaped into a sparse matrix to support message passing by a
single sparse matrix multiplication, cf. :func:`pykeen.nn.utils.adjacency_tensor_to_stacked_matrix`.
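As a minimal sketch of the rank-$1$ intuition above (plain :mod:`torch`; the size $d$ is chosen arbitrarily):

>>> import torch
>>> d = 1024
>>> w, v = torch.rand(d), torch.rand(d)
>>> full = torch.outer(w, w) @ v  # materializes a (d, d) intermediate matrix
>>> cheap = w * (w @ v)  # the intermediate result is a single scalar
>>> assert torch.allclose(full, cheap, atol=1.0e-3)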
.. note ::
this module neither takes care of the self-loop nor applies an activation function.
"""
def __init__(
self,
num_relations: int,
input_dim: int = 32,
output_dim: Optional[int] = None,
):
"""Initialize the layer.
:param num_relations: >0
The number of relations.
:param input_dim: >0
The input dimension.
:param output_dim: >0
The output dimension. If None is given, defaults to input_dim.
"""
super().__init__()
# input normalization
if output_dim is None:
output_dim = input_dim
self.input_dim = input_dim
self.num_relations = num_relations
self.output_dim = output_dim
def forward(
self,
x: FloatTensor,
source: LongTensor,
target: LongTensor,
edge_type: LongTensor,
edge_weights: Optional[FloatTensor] = None,
accumulator: Optional[FloatTensor] = None,
) -> FloatTensor:
"""Relation-specific message passing from source to target.
:param x: shape: (num_nodes, input_dim)
The node representations.
:param source: shape: (num_edges,)
The source indices.
:param target: shape: (num_edges,)
The target indices.
:param edge_type: shape: (num_edges,)
The edge types.
:param edge_weights: shape: (num_edges,)
Precomputed edge weights.
:param accumulator: shape: (num_nodes, output_dim)
A pre-allocated output accumulator. May be used if multiple different message passing steps are performed
and accumulated by summation. If None is given, this is equivalent to a zero-filled accumulator.
:return: shape: (num_nodes, output_dim)
The enriched node embeddings.
"""
horizontal = use_horizontal_stacking(input_dim=self.input_dim, output_dim=self.output_dim)
adj = adjacency_tensor_to_stacked_matrix(
num_relations=self.num_relations,
num_entities=x.shape[0],
source=source,
target=target,
edge_type=edge_type,
edge_weights=edge_weights,
horizontal=horizontal,
)
if horizontal:
x = self.forward_horizontally_stacked(x=x, adj=adj)
else:
x = self.forward_vertically_stacked(x=x, adj=adj)
if accumulator is not None:
x = accumulator + x
return x
@abstractmethod
def forward_horizontally_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
"""
Forward pass for horizontally stacked adjacency.
:param x: shape: `(num_entities, input_dim)`
the input entity representations
:param adj: shape: `(num_entities, num_relations * num_entities)`, sparse
the horizontally stacked adjacency matrix
:return: shape: `(num_entities, output_dim)`
the updated entity representations.
"""
raise NotImplementedError
@abstractmethod
def forward_vertically_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
"""
Forward pass for vertically stacked adjacency.
:param x: shape: `(num_entities, input_dim)`
the input entity representations
:param adj: shape: `(num_entities * num_relations, num_entities)`, sparse
the vertically stacked adjacency matrix
:return: shape: `(num_entities, output_dim)`
the updated entity representations.
"""
raise NotImplementedError
def reset_parameters(self):
"""Reset the layer's parameters."""
# note: the base class does not have any parameters
class BasesDecomposition(Decomposition):
r"""
Represent relation-weights as a linear combination of base transformation matrices.
The basis decomposition represents the relation-specific transformation matrices
as a weighted combination of base matrices, $\{\mathbf{B}_b^l\}_{b=1}^{B}$, i.e.,
.. math::
\mathbf{W}_r^l = \sum \limits_{b=1}^B \alpha_{rb} \mathbf{B}^l_b
The implementation uses a reshaping of the adjacency tensor into a sparse matrix to support message passing by a
single sparse matrix multiplication, cf. [thanapalasingam2021]_.
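As a minimal sketch of the weighted combination (shapes and names are illustrative only; cf. the
``base_weights`` and ``bases`` properties of this class):

>>> import torch
>>> num_relations, num_bases, dim = 5, 3, 8
>>> base_weights = torch.rand(num_relations, num_bases)  # alpha_{rb}
>>> bases = torch.rand(num_bases, dim, dim)  # B_b
>>> # relation-specific matrices W_r = sum_b alpha_{rb} * B_b
>>> w = torch.einsum("rb, bio -> rio", base_weights, bases)
>>> assert w.shape == (num_relations, dim, dim)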
.. seealso ::
https://github.com/thiviyanT/torch-rgcn/blob/267faffd09a441d902c483a8c130410c72910e90/torch_rgcn/layers.py#L450-L565
"""
def __init__(self, num_bases: Optional[int] = None, **kwargs):
"""
Initialize the bases decomposition.
:param num_bases:
the number of bases
:param kwargs:
additional keyword-based parameters passed to :meth:`Decomposition.__init__`
"""
super().__init__(**kwargs)
# Heuristic for default value
if num_bases is None:
num_bases = math.ceil(math.sqrt(self.num_relations))
logger.info(f"No num_bases was provided. Using sqrt(num_relations)={num_bases}.")
if num_bases > self.num_relations:
logger.warning(f"The number of bases ({num_bases}) exceeds the number of relations ({self.num_relations}).")
self.relation_representations = LowRankRepresentation(
max_id=self.num_relations,
shape=(self.input_dim, self.output_dim),
num_bases=num_bases,
weight_initializer=uniform_norm_p1_,
initializer=nn.init.xavier_normal_,
)
# docstr-coverage: inherited
@property
def bases(self) -> torch.Tensor:
"""Return the base representations."""
return self.relation_representations.bases(indices=None)
@property
def base_weights(self) -> torch.Tensor:
"""Return the base weights."""
return self.relation_representations.weight
# docstr-coverage: inherited
def reset_parameters(self): # noqa: D102
# note: the only parameters are inside the relation representation module, which has its own reset_parameters
pass
# docstr-coverage: inherited
def forward_horizontally_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: # noqa: D102
x = einsum("ni, rb, bio -> rno", x, self.base_weights, self.bases)
# TODO: can we change the dimension order to make this contiguous?
return torch.spmm(adj, x.reshape(-1, self.output_dim))
# docstr-coverage: inherited
def forward_vertically_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: # noqa: D102
x = torch.spmm(adj, x)
x = x.view(self.num_relations, -1, self.input_dim)
return einsum("rb, bio, rni -> no", self.base_weights, self.bases, x)
def _make_dim_divisible(dim: int, divisor: int, name: str) -> int:
"""Round the dimension up to the next multiple of the divisor, warning if padding will be necessary."""
dim_div, remainder = divmod(dim, divisor)
if remainder:
logger.warning(f"{name}={dim} is not divisible by {divisor}. Applying padding.")
dim = (dim_div + 1) * divisor
assert dim % divisor == 0
return dim
def _pad_if_necessary(x: torch.Tensor, dim: int) -> torch.Tensor:
"""Apply padding if necessary."""
padding_dim = dim - x.shape[-1]
if padding_dim < 0:
raise ValueError("Cannot have a negative padding")
if padding_dim == 0:
return x
return torch.cat([x, x.new_zeros(*x.shape[:-1], padding_dim)], dim=-1)
def _unpad_if_necessary(x: torch.Tensor, dim: int) -> torch.Tensor:
"""Remove padding if necessary."""
padding_dim = dim - x.shape[-1]
if padding_dim < 0:
raise ValueError("Cannot have a negative padding")
if padding_dim == 0:
return x
return x[..., :-padding_dim]
class BlockDecomposition(Decomposition):
r"""
Represent relation-specific weight matrices via block-diagonal matrices.
The block-diagonal decomposition restricts each transformation matrix to a block-diagonal-matrix, i.e.,
.. math::
\mathbf{W}_r^l = diag(\mathbf{B}_{r,1}^l, \ldots, \mathbf{B}_{r,B}^l)
where $\mathbf{B}_{r,i} \in \mathbb{R}^{(d^{(l)} / B) \times (d^{(l)} / B)}$.
The implementation is based on the efficient version of [thanapalasingam2021]_, which uses a reshaping of the
adjacency tensor into a sparse matrix to support message passing by a single sparse matrix multiplication.
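As a minimal sketch of how such a block-diagonal matrix acts on a (row) vector, where ``torch.block_diag`` is
used for illustration only and the full matrix is never materialized by this module:

>>> import torch
>>> num_blocks, block_size = 4, 3
>>> blocks = torch.rand(num_blocks, block_size, block_size)
>>> w_r = torch.block_diag(*blocks)  # shape: (12, 12)
>>> x = torch.rand(num_blocks * block_size)
>>> # equivalent block-wise computation without building w_r
>>> y = torch.einsum("bio, bi -> bo", blocks, x.view(num_blocks, block_size)).reshape(-1)
>>> assert torch.allclose(x @ w_r, y, atol=1.0e-6)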
.. seealso ::
https://github.com/thiviyanT/torch-rgcn/blob/267faffd09a441d902c483a8c130410c72910e90/torch_rgcn/layers.py#L450-L565
"""
def __init__(self, num_blocks: Optional[int] = None, **kwargs):
"""
Initialize the layer.
:param num_blocks:
the number of blocks.
:param kwargs:
keyword-based parameters passed to :meth:`Decomposition.__init__`.
"""
super().__init__(**kwargs)
# normalize num blocks
if num_blocks is None:
num_blocks = math.gcd(self.input_dim, self.output_dim)
logger.info(f"Inferred num_blocks={num_blocks} by GCD heuristic.")
self.num_blocks = num_blocks
# determine necessary padding
self.padded_input_dim = _make_dim_divisible(dim=self.input_dim, divisor=num_blocks, name="input_dim")
self.padded_output_dim = _make_dim_divisible(dim=self.output_dim, divisor=num_blocks, name="output_dim")
# determine block sizes
self.input_block_size = self.padded_input_dim // num_blocks
self.output_block_size = self.padded_output_dim // num_blocks
# (R, nb, bsi, bso)
self.blocks = nn.Parameter(
data=torch.empty(
self.num_relations,
num_blocks,
self.input_block_size,
self.output_block_size,
),
requires_grad=True,
)
self.reset_parameters()
def reset_parameters(self):
"""Reset the layer's parameters."""
xavier_normal_(self.blocks.data)
# docstr-coverage: inherited
def forward_horizontally_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: # noqa: D102
# apply padding if necessary
x = _pad_if_necessary(x=x, dim=self.padded_input_dim)
# (n, di) -> (n, nb, bsi)
x = x.view(x.shape[0], self.num_blocks, self.input_block_size)
# (n, nb, bsi), (R, nb, bsi, bso) -> (R, n, nb, bso)
x = einsum("nbi, rbio -> rnbo", x, self.blocks)
# (R, n, nb, bso) -> (R * n, do)
# note: depending on the contraction order, the output may or may not support viewing
x = x.reshape(-1, self.num_blocks * self.output_block_size)
# (n, R * n), (R * n, do) -> (n, do)
x = torch.sparse.mm(adj, x)
# remove padding if necessary
return _unpad_if_necessary(x=x, dim=self.padded_output_dim)
# docstr-coverage: inherited
def forward_vertically_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: # noqa: D102
# apply padding if necessary
x = _pad_if_necessary(x=x, dim=self.padded_input_dim)
# (R * n, n), (n, di) -> (R * n, di)
x = torch.sparse.mm(adj, x)
# (R * n, di) -> (R, n, nb, bsi)
x = x.view(self.num_relations, -1, self.num_blocks, self.input_block_size)
# (R, nb, bsi, bso), (R, n, nb, bsi) -> (n, nb, bso)
x = einsum("rbio, rnbi -> nbo", self.blocks, x)
# (n, nb, bso) -> (n, do)
# note: depending on the contraction order, the output may or may not support viewing
x = x.reshape(x.shape[0], self.num_blocks * self.output_block_size)
# remove padding if necessary
return _unpad_if_necessary(x=x, dim=self.padded_output_dim)
class RGCNLayer(nn.Module):
r"""
An RGCN layer from [schlichtkrull2018]_ updated to match the official implementation.
This layer uses separate decompositions for forward and backward edges (i.e., "normal" and implicitly created
inverse relations), as well as a separate transformation for self-loops.
Ignoring dropouts, decomposition and normalization, it can be written as
.. math ::
y_i = \sigma(
W^s x_i
+ \sum_{(e_j, r, e_i) \in \mathcal{T}} W^f_r x_j
+ \sum_{(e_i, r, e_j) \in \mathcal{T}} W^b_r x_j
+ b
)
where $b, W^s, W^f_r, W^b_r$ are trainable weights. $W^f_r, W^b_r$ are relation-specific and commonly employ a
weight-sharing mechanism, cf. Decomposition. $\sigma$ is an activation function. The individual terms in both sums
are typically weighted; this is implemented by EdgeWeighting. Moreover, R-GCN employs an edge dropout; however,
this needs to be applied outside of an individual layer, since the same edges are dropped across all layers. In
contrast, the self-loop dropout is layer-specific.
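A minimal usage sketch on a hypothetical toy graph with four entities, two relation types, and three edges
(the import path is assumed from this module):

>>> import torch
>>> from pykeen.nn.message_passing import RGCNLayer
>>> layer = RGCNLayer(num_relations=2, input_dim=8, output_dim=8, activation="relu")
>>> x = torch.rand(4, 8)
>>> source = torch.as_tensor([0, 1, 2])
>>> target = torch.as_tensor([1, 2, 3])
>>> edge_type = torch.as_tensor([0, 1, 0])
>>> y = layer(x=x, source=source, target=target, edge_type=edge_type)
>>> assert y.shape == (4, 8)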
"""
def __init__(
self,
num_relations: int,
input_dim: int = 32,
output_dim: Optional[int] = None,
use_bias: bool = True,
activation: Hint[nn.Module] = None,
activation_kwargs: Optional[Mapping[str, Any]] = None,
self_loop_dropout: float = 0.2,
decomposition: Hint[Decomposition] = None,
decomposition_kwargs: Optional[Mapping[str, Any]] = None,
):
"""
Initialize the layer.
:param num_relations:
the number of relations
:param input_dim: >0
the input dimension
:param output_dim: >0
the output dimension. If None is given, use the input dimension.
:param use_bias:
whether to use a trainable bias
:param activation:
the activation function to use. Defaults to None, i.e., the identity function serves as activation.
:param activation_kwargs:
additional keyword-based arguments passed to the activation function for instantiation
:param self_loop_dropout: 0 <= self_loop_dropout <= 1
the dropout to use for self-loops
:param decomposition:
the decomposition to use, cf. Decomposition and decomposition_resolver
:param decomposition_kwargs:
the keyword-based arguments passed to the decomposition for instantiation
"""
super().__init__()
# cf. https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/encoders/message_gcns/gcn_basis.py#L22-L24 # noqa: E501
# there are separate decompositions for forward and backward relations.
# the self-loop weight is not decomposed.
self.fwd = decomposition_resolver.make(
query=decomposition,
pos_kwargs=decomposition_kwargs,
input_dim=input_dim,
output_dim=output_dim,
num_relations=num_relations,
)
self.bwd = decomposition_resolver.make(
query=decomposition,
pos_kwargs=decomposition_kwargs,
input_dim=input_dim,
output_dim=output_dim,
num_relations=num_relations,
)
self.self_loop = nn.Linear(in_features=input_dim, out_features=self.fwd.output_dim, bias=use_bias)
self.dropout = nn.Dropout(p=self_loop_dropout)
self.activation = activation_resolver.make_safe(query=activation, pos_kwargs=activation_kwargs)
def forward(
self,
x: FloatTensor,
source: LongTensor,
target: LongTensor,
edge_type: LongTensor,
edge_weights: Optional[FloatTensor] = None,
) -> FloatTensor:
"""
Calculate enriched entity representations.
:param x: shape: (num_entities, input_dim)
The input entity representations.
:param source: shape: (num_triples,)
The indices of the source entity per triple.
:param target: shape: (num_triples,)
The indices of the target entity per triple.
:param edge_type: shape: (num_triples,)
The relation type per triple.
:param edge_weights: shape: (num_triples,)
Scalar edge weights per triple.
:return: shape: (num_entities, output_dim)
Enriched entity representations.
"""
# TODO: we could cache the stacked adjacency matrices
# self-loop
y = self.dropout(self.self_loop(x))
# forward messages
y = self.fwd(x=x, source=source, target=target, edge_type=edge_type, edge_weights=edge_weights, accumulator=y)
# backward messages
y = self.bwd(x=x, source=target, target=source, edge_type=edge_type, edge_weights=edge_weights, accumulator=y)
# activation
if self.activation is not None:
y = self.activation(y)
return y
decomposition_resolver: ClassResolver[Decomposition] = ClassResolver.from_subclasses(
base=Decomposition, default=BasesDecomposition
)
@parse_docdata
class RGCNRepresentation(Representation):
r"""Entity representations enriched by R-GCN.
The GCN employed by the entity encoder is adapted to include typed edges.
The forward pass of the GCN is defined by:
.. math::
\textbf{e}_{i}^{l+1} = \sigma \left( \sum_{r \in \mathcal{R}}\sum_{j\in \mathcal{N}_{i}^{r}}
\frac{1}{c_{i,r}} \textbf{W}_{r}^{l} \textbf{e}_{j}^{l} + \textbf{W}_{0}^{l} \textbf{e}_{i}^{l}\right)
where $\mathcal{N}_{i}^{r}$ is the set of neighbors of node $i$ that are connected to
$i$ by relation $r$, $c_{i,r}$ is a fixed normalization constant (but it can also be introduced as an additional
parameter), and $\textbf{W}_{r}^{l} \in \mathbb{R}^{d^{(l)} \times d^{(l)}}$ and
$\textbf{W}_{0}^{l} \in \mathbb{R}^{d^{(l)} \times d^{(l)}}$ are weight matrices of the $l$-th layer of the
R-GCN.
The encoder aggregates for each node $e_i$ the latent representations of its neighbors and its
own latent representation $e_{i}^{l}$ into a new latent representation $e_{i}^{l+1}$.
In contrast to standard GCN, R-GCN defines relation specific transformations
$\textbf{W}_{r}^{l}$ which depend on the type and direction of an edge.
Since having one matrix for each relation introduces a large number of additional parameters, the authors instead
propose to use a decomposition, cf. :class:`pykeen.nn.message_passing.Decomposition`.
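A minimal usage sketch, where ``Nations`` is merely an exemplary small dataset bundled with PyKEEN, and the
explicit ``embedding_dim`` is an arbitrary choice for the default base representation:

>>> from pykeen.datasets import Nations
>>> from pykeen.nn.message_passing import RGCNRepresentation
>>> dataset = Nations()
>>> kwargs = dict(embedding_dim=32)
>>> r = RGCNRepresentation(triples_factory=dataset.training, entity_representations_kwargs=kwargs)
>>> x = r(indices=None)  # enriched representations for all entities
>>> assert x.shape == (dataset.training.num_entities, 32)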
---
name: R-GCN
citation:
author: Schlichtkrull
year: 2018
link: https://arxiv.org/pdf/1703.06103
github: https://github.com/MichSchli/RelationPrediction
"""
def __init__(
self,
triples_factory: CoreTriplesFactory,
max_id: Optional[int] = None,
shape: Optional[Sequence[int]] = None,
entity_representations: HintOrType[Representation] = None,
entity_representations_kwargs: OptionalKwargs = None,
num_layers: int = 2,
use_bias: bool = True,
activation: Hint[nn.Module] = None,
activation_kwargs: Optional[Mapping[str, Any]] = None,
edge_dropout: float = 0.4,
self_loop_dropout: float = 0.2,
edge_weighting: Hint[EdgeWeighting] = None,
decomposition: Hint[Decomposition] = None,
decomposition_kwargs: Optional[Mapping[str, Any]] = None,
cache: bool = True,
**kwargs,
):
"""Instantiate the R-GCN encoder.
:param triples_factory:
The triples factory holding the training triples used for message passing.
:param max_id:
The maximum number of IDs. May either be None (the default) or match the triples factory's number of entities.
:param shape:
the shape information. If None, will propagate the shape information of the base entity representations.
:param entity_representations:
the base entity representations (or a hint for them)
:param entity_representations_kwargs:
additional keyword-based parameters for the base entity representations
:param num_layers:
The number of layers.
:param use_bias:
Whether to use a bias.
:param activation:
The activation.
:param activation_kwargs:
Additional keyword-based arguments passed if the activation is not pre-instantiated. Ignored otherwise.
:param edge_dropout:
The edge dropout to use. Does not apply to self-loops.
:param self_loop_dropout:
The self-loop dropout to use.
:param edge_weighting:
The edge weighting mechanism.
:param decomposition:
The decomposition, cf. :class:`pykeen.nn.message_passing.Decomposition`.
:param decomposition_kwargs:
Additional keyword-based arguments passed to the decomposition upon instantiation.
:param cache:
whether to cache the enriched representations
:param kwargs:
additional keyword-based parameters passed to :meth:`Representation.__init__`
:raises ValueError: If the triples factory creates inverse triples.
"""
# input validation
if max_id and max_id != triples_factory.num_entities:
raise ValueError(
f"max_id={max_id} differs from triples_factory.num_entities={triples_factory.num_entities}"
)
if triples_factory.create_inverse_triples:
raise ValueError(
"RGCN internally creates inverse triples. It thus expects a triples factory without them.",
)
# has to be imported now to avoid cyclic imports
from . import representation_resolver
base = representation_resolver.make(
entity_representations,
max_id=triples_factory.num_entities,
pos_kwargs=entity_representations_kwargs,
)
if len(base.shape) > 1:
raise ValueError(f"{self.__class__.__name__} requires vector base entity representations.")
max_id = max_id or triples_factory.num_entities
if max_id != base.max_id:
raise ValueError(f"Inconsistent max_id={max_id} vs. base.max_id={base.max_id}")
shape = ShapeError.verify(shape=base.shape, reference=shape)
super().__init__(max_id=max_id, shape=shape, **kwargs)
# has to be assigned after call to nn.Module init
self.entity_embeddings = base
# Resolve edge weighting
self.edge_weighting = edge_weight_resolver.make(query=edge_weighting)
# dropout
self.edge_dropout = edge_dropout
self_loop_dropout = self_loop_dropout or edge_dropout
# Save graph using buffers, such that the tensors are moved together with the model
h, r, t = triples_factory.mapped_triples.t()
self.register_buffer("sources", h)
self.register_buffer("targets", t)
self.register_buffer("edge_types", r)
dim = base.shape[0]
self.layers = nn.ModuleList(
RGCNLayer(
input_dim=dim,
num_relations=triples_factory.num_relations,
output_dim=dim,
use_bias=use_bias,
# no activation on last layer
# cf. https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/common/model_builder.py#L275 # noqa: E501
activation=activation if i < num_layers - 1 else None,
activation_kwargs=activation_kwargs,
self_loop_dropout=self_loop_dropout,
decomposition=decomposition,
decomposition_kwargs=decomposition_kwargs,
)
for i in range(num_layers)
)
# buffering of enriched representations
self.enriched_embeddings = None
self.cache = cache
# docstr-coverage: inherited
def post_parameter_update(self) -> None: # noqa: D102
super().post_parameter_update()
# invalidate enriched embeddings
self.enriched_embeddings = None
# docstr-coverage: inherited
def reset_parameters(self): # noqa: D102
self.entity_embeddings.reset_parameters()
for m in self.layers:
if hasattr(m, "reset_parameters"):
m.reset_parameters()
elif any(p.requires_grad for p in m.parameters()):
logger.warning("Layers %s has parameters, but no reset_parameters.", m)
def _real_forward_all(self) -> FloatTensor:
if self.enriched_embeddings is not None:
return self.enriched_embeddings
# Bind fields
# shape: (num_entities, embedding_dim)
x = self.entity_embeddings(indices=None)
sources = self.sources
targets = self.targets
edge_types = self.edge_types
# Edge dropout: drop the same edges on all layers (only in training mode)
if self.training and self.edge_dropout is not None:
# Get random dropout mask
edge_keep_mask = torch.rand(self.sources.shape[0], device=x.device) > self.edge_dropout
# Apply to edges
sources = sources[edge_keep_mask]
targets = targets[edge_keep_mask]
edge_types = edge_types[edge_keep_mask]
# fixed edges -> pre-compute weights
if self.edge_weighting is not None and sources.numel() > 0:
edge_weights = torch.empty_like(sources, dtype=torch.float32)
for r in range(edge_types.max().item() + 1):
mask = edge_types == r
if mask.any():
edge_weights[mask] = self.edge_weighting(sources[mask], targets[mask])
else:
edge_weights = None
for layer in self.layers:
x = layer(
x=x,
source=sources,
target=targets,
edge_type=edge_types,
edge_weights=edge_weights,
)
# Cache enriched representations
if self.cache:
self.enriched_embeddings = x
return x
def _plain_forward(
self,
indices: Optional[LongTensor] = None,
) -> FloatTensor:
"""Enrich the entity embeddings of the decoder using R-GCN message propagation."""
x = self._real_forward_all()
if indices is not None:
x = x[indices]
return x