# -*- coding: utf-8 -*-
"""Implementation of the R-GCN model."""
import logging
from os import path
from typing import Any, Callable, Mapping, Optional, Type
import torch
from torch import nn
from torch.nn import functional
from . import ComplEx, DistMult, ERMLP
from .. import EntityEmbeddingModel
from ..base import Model
from ...losses import Loss
from ...triples import TriplesFactory
__all__ = [
    'RGCN',
]

# Use the module's dotted path as the logger name (standard logging convention),
# so that per-package logging configuration applies to this module as expected.
logger = logging.getLogger(__name__)
def _get_neighborhood(
start_nodes: torch.LongTensor,
sources: torch.LongTensor,
targets: torch.LongTensor,
k: int,
num_nodes: int,
undirected: bool = False,
) -> torch.BoolTensor:
# Construct node neighbourhood mask
node_mask = torch.zeros(num_nodes, device=start_nodes.device, dtype=torch.bool)
# Set nodes in batch to true
node_mask[start_nodes] = True
# Compute k-neighbourhood
for _ in range(k):
# if the target node needs an embeddings, so does the source node
node_mask[sources] |= node_mask[targets]
if undirected:
node_mask[targets] |= node_mask[sources]
# Create edge mask
edge_mask = node_mask[targets]
if undirected:
edge_mask |= node_mask[sources]
return edge_mask
# pylint: disable=unused-argument
def inverse_indegree_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
    """Normalize messages by inverse in-degree.

    :param source: shape: (num_edges,)
        The source indices.

    :param target: shape: (num_edges,)
        The target indices.

    :return: shape: (num_edges,)
        The edge weights.
    """
    # In-degree of each edge's target node, i.e. the number of edges pointing at it.
    in_degree = torch.bincount(target)[target]
    # Each incoming message is weighted by 1 / in-degree of its target.
    return in_degree.float().reciprocal()
# pylint: disable=unused-argument
def inverse_outdegree_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
    """Normalize messages by inverse out-degree.

    :param source: shape: (num_edges,)
        The source indices.

    :param target: shape: (num_edges,)
        The target indices.

    :return: shape: (num_edges,)
        The edge weights.
    """
    # Calculate out-degree, i.e. number of outgoing edges per source node
    _, inv, cnt = torch.unique(source, return_counts=True, return_inverse=True)
    return cnt[inv].float().reciprocal()
def symmetric_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
    """Normalize messages by product of inverse sqrt of in-degree and out-degree.

    :param source: shape: (num_edges,)
        The source indices.

    :param target: shape: (num_edges,)
        The target indices.

    :return: shape: (num_edges,)
        The edge weights.
    """
    # 1 / in-degree and 1 / out-degree per edge, combined symmetrically:
    # sqrt((1/d_in) * (1/d_out)).
    in_weights = inverse_indegree_edge_weights(source=source, target=target)
    out_weights = inverse_outdegree_edge_weights(source=source, target=target)
    return (in_weights * out_weights).sqrt()
class RGCN(Model):
    """An implementation of R-GCN from [schlichtkrull2018]_.

    This model uses graph convolutions with relation-specific weights.

    .. seealso::

       - `Pytorch Geometric's implementation of R-GCN
         <https://github.com/rusty1s/pytorch_geometric/blob/1.3.2/examples/rgcn.py>`_
       - `DGL's implementation of R-GCN
         <https://github.com/dmlc/dgl/tree/v0.4.0/examples/pytorch/rgcn>`_
    """

    #: Interaction model used as decoder
    base_model: EntityEmbeddingModel

    #: The blocks of the relation-specific weight matrices
    #: shape: (num_relations, num_blocks, embedding_dim//num_blocks, embedding_dim//num_blocks)
    blocks: Optional[nn.ParameterList]

    #: The base weight matrices to generate relation-specific weights
    #: shape: (num_bases, embedding_dim, embedding_dim)
    bases: Optional[nn.ParameterList]

    #: The relation-specific weights for each base
    #: shape: (num_relations, num_bases)
    att: Optional[nn.ParameterList]

    #: The biases for each layer (if used)
    #: shape of each element: (embedding_dim,)
    biases: Optional[nn.ParameterList]

    #: Batch normalization for each layer (if used)
    batch_norms: Optional[nn.ModuleList]

    #: Activations for each layer (if used)
    activations: Optional[nn.ModuleList]

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=50, high=1000, q=50),
        num_bases_or_blocks=dict(type=int, low=2, high=20, q=1),
        num_layers=dict(type=int, low=1, high=5, q=1),
        use_bias=dict(type='bool'),
        use_batch_norm=dict(type='bool'),
        activation_cls=dict(type='categorical', choices=[None, nn.ReLU, nn.LeakyReLU]),
        base_model_cls=dict(type='categorical', choices=[DistMult, ComplEx, ERMLP]),
        edge_dropout=dict(type=float, low=0.0, high=.9),
        self_loop_dropout=dict(type=float, low=0.0, high=.9),
        edge_weighting=dict(type='categorical', choices=[
            None,
            inverse_indegree_edge_weights,
            inverse_outdegree_edge_weights,
            symmetric_edge_weights,
        ]),
        decomposition=dict(type='categorical', choices=['basis', 'block']),
    )
def __init__(
self,
triples_factory: TriplesFactory,
embedding_dim: int = 500,
automatic_memory_optimization: Optional[bool] = None,
loss: Optional[Loss] = None,
predict_with_sigmoid: bool = False,
preferred_device: Optional[str] = None,
random_seed: Optional[int] = None,
num_bases_or_blocks: int = 5,
num_layers: int = 2,
use_bias: bool = True,
use_batch_norm: bool = False,
activation_cls: Optional[Type[nn.Module]] = None,
activation_kwargs: Optional[Mapping[str, Any]] = None,
base_model: Optional[Model] = None,
sparse_messages_slcwa: bool = True,
edge_dropout: float = 0.4,
self_loop_dropout: float = 0.2,
edge_weighting: Callable[
[torch.LongTensor, torch.LongTensor],
torch.FloatTensor
] = inverse_indegree_edge_weights,
decomposition: str = 'basis',
buffer_messages: bool = True,
):
super().__init__(
triples_factory=triples_factory,
automatic_memory_optimization=automatic_memory_optimization,
loss=loss,
predict_with_sigmoid=predict_with_sigmoid,
preferred_device=preferred_device,
random_seed=random_seed,
)
if self.triples_factory.create_inverse_triples:
raise ValueError('R-GCN handles edges in an undirected manner.')
if base_model is None:
# Instantiate model
base_model = DistMult(
triples_factory=triples_factory,
embedding_dim=embedding_dim,
automatic_memory_optimization=automatic_memory_optimization,
loss=loss,
preferred_device=preferred_device,
random_seed=random_seed,
)
self.base_model = base_model
self.base_embeddings = nn.Parameter(
data=torch.empty(
self.triples_factory.num_entities,
embedding_dim,
device=self.device,
),
requires_grad=True,
)
self.embedding_dim = embedding_dim
self.decomposition = decomposition
# Heuristic
if self.decomposition == 'basis':
if num_bases_or_blocks is None:
logging.info('Using a heuristic to determine the number of bases.')
num_bases_or_blocks = triples_factory.num_relations // 2 + 1
if num_bases_or_blocks > triples_factory.num_relations:
raise ValueError('The number of bases should not exceed the number of relations.')
elif self.decomposition == 'block':
if num_bases_or_blocks is None:
logging.info('Using a heuristic to determine the number of blocks.')
num_bases_or_blocks = 2
if embedding_dim % num_bases_or_blocks != 0:
raise ValueError(
'With block decomposition, the embedding dimension has to be divisible by the number of'
f' blocks, but {embedding_dim} % {num_bases_or_blocks} != 0.'
)
else:
raise ValueError(f'Unknown decomposition: "{decomposition}". Please use either "basis" or "block".')
self.num_bases = num_bases_or_blocks
# buffering of messages
self.buffer_messages = buffer_messages
self.enriched_embeddings = None
self.edge_weighting = edge_weighting
self.edge_dropout = edge_dropout
if self_loop_dropout is None:
self_loop_dropout = edge_dropout
self.self_loop_dropout = self_loop_dropout
self.use_batch_norm = use_batch_norm
if activation_cls is None:
activation_cls = nn.ReLU
self.activation_cls = activation_cls
self.activation_kwargs = activation_kwargs
if use_batch_norm:
if use_bias:
logger.warning('Disabling bias because batch normalization was used.')
use_bias = False
self.use_bias = use_bias
self.num_layers = num_layers
self.sparse_messages_slcwa = sparse_messages_slcwa
# Save graph using buffers, such that the tensors are moved together with the model
h, r, t = self.triples_factory.mapped_triples.t()
self.register_buffer('sources', h)
self.register_buffer('targets', t)
self.register_buffer('edge_types', r)
self.activations = nn.ModuleList([
self.activation_cls(**(self.activation_kwargs or {})) for _ in range(self.num_layers)
])
# Weights
self.bases = nn.ParameterList()
if self.decomposition == 'basis':
self.att = nn.ParameterList()
for _ in range(self.num_layers):
self.bases.append(
nn.Parameter(
torch.empty(
self.num_bases,
self.embedding_dim,
self.embedding_dim,
device=self.device,
), requires_grad=True)
)
self.att.append(
nn.Parameter(
torch.empty(
self.num_relations + 1,
self.num_bases,
device=self.device,
), requires_grad=True)
)
elif self.decomposition == 'block':
block_size = self.embedding_dim // self.num_bases
for _ in range(self.num_layers):
self.bases.append(
nn.Parameter(
data=torch.empty(
self.num_relations + 1,
self.num_bases,
block_size,
block_size,
device=self.device,
), requires_grad=True)
)
self.att = None
else:
raise NotImplementedError
if self.use_bias:
self.biases = nn.ParameterList([
nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True)
for _ in range(self.num_layers)
])
else:
self.biases = None
if self.use_batch_norm:
self.batch_norms = nn.ModuleList([
nn.BatchNorm1d(num_features=self.embedding_dim)
for _ in range(self.num_layers)
])
else:
self.batch_norms = None
# Finalize initialization
self.reset_parameters_()
[docs] def post_parameter_update(self) -> None: # noqa: D102
super().post_parameter_update()
# invalidate enriched embeddings
self.enriched_embeddings = None
    def _reset_parameters_(self):
        """Reset all trainable parameters: decoder, base embeddings, weights, biases, batch norms, activations."""
        self.base_model.reset_parameters_()

        # Initialization scheme following the reference implementation:
        # https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/encoders/affine_transform.py#L24-L28
        nn.init.xavier_uniform_(self.base_embeddings)

        # NOTE(review): calculate_gain only accepts a fixed set of nonlinearity names;
        # presumably activation_cls is limited to those (e.g. ReLU, LeakyReLU) — confirm.
        gain = nn.init.calculate_gain(nonlinearity=self.activation_cls.__name__.lower())
        if self.decomposition == 'basis':
            for base in self.bases:
                nn.init.xavier_normal_(base, gain=gain)
            for att in self.att:
                # Random convex-combination of bases for initialization (guarantees that initial weight matrices are
                # initialized properly)
                # We have one additional relation for self-loops
                nn.init.uniform_(att)
                # L1-normalize each relation's coefficients in-place so they sum to 1.
                functional.normalize(att.data, p=1, dim=1, out=att.data)
        elif self.decomposition == 'block':
            for base in self.bases:
                block_size = base.shape[-1]
                # Xavier Glorot initialization of each block
                std = torch.sqrt(torch.as_tensor(2.)) * gain / (2 * block_size)
                nn.init.normal_(base, std=std)

        # Reset biases
        if self.biases is not None:
            for bias in self.biases:
                nn.init.zeros_(bias)

        # Reset batch norm parameters
        if self.batch_norms is not None:
            for bn in self.batch_norms:
                bn.reset_parameters()

        # Reset activation parameters, if any
        for act in self.activations:
            if hasattr(act, 'reset_parameters'):
                act.reset_parameters()
    def _enrich_embeddings(self, batch: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        """
        Enrich the entity embeddings using R-GCN message propagation.

        :param batch: shape: (batch_size, 3)
            If given, messages are only propagated within the (num_layers)-hop neighbourhood
            of the batch's head and tail entities, and the result is not buffered.

        :return: shape: (num_entities, embedding_dim)
            The updated entity embeddings
        """
        # use buffered messages if applicable
        if batch is None and self.enriched_embeddings is not None:
            return self.enriched_embeddings

        # Bind fields
        # shape: (num_entities, embedding_dim)
        x = self.base_embeddings
        sources = self.sources
        targets = self.targets
        edge_types = self.edge_types

        # Edge dropout: drop the same edges on all layers (only in training mode)
        if self.training and self.edge_dropout is not None:
            # Get random dropout mask
            edge_keep_mask = torch.rand(self.sources.shape[0], device=x.device) > self.edge_dropout

            # Apply to edges
            sources = sources[edge_keep_mask]
            targets = targets[edge_keep_mask]
            edge_types = edge_types[edge_keep_mask]

        # Different dropout for self-loops (only in training mode)
        if self.training and self.self_loop_dropout is not None:
            node_keep_mask = torch.rand(self.num_entities, device=x.device) > self.self_loop_dropout
        else:
            node_keep_mask = None

        # If batch is given, compute (num_layers)-hop neighbourhood
        if batch is not None:
            # Heads (column 0) and tails (column 2) of the batch triples are the seed nodes.
            start_nodes = torch.cat([batch[:, 0], batch[:, 2]], dim=0)
            edge_mask = _get_neighborhood(
                start_nodes=start_nodes,
                sources=sources,
                targets=targets,
                k=self.num_layers,
                num_nodes=self.num_entities,
                undirected=True,
            )
        else:
            edge_mask = None

        for i in range(self.num_layers):
            # Initialize embeddings in the next layer for all nodes
            new_x = torch.zeros_like(x)

            # TODO: Can we vectorize this loop?
            for r in range(self.num_relations):
                # Choose the edges which are of the specific relation
                mask = (edge_types == r)

                # Only propagate messages on subset of edges
                if edge_mask is not None:
                    mask &= edge_mask

                # No edges available? Skip rest of inner loop
                if not mask.any():
                    continue

                # Get source and target node indices
                sources_r = sources[mask]
                targets_r = targets[mask]

                # send messages in both directions
                sources_r, targets_r = torch.cat([sources_r, targets_r]), torch.cat([targets_r, sources_r])

                # Select source node embeddings
                x_s = x[sources_r]

                # get relation weights
                w = self._get_relation_weights(i_layer=i, r=r)

                # Compute message (b x d) * (d x d) = (b x d)
                m_r = x_s @ w

                # Normalize messages by relation-specific in-degree
                if self.edge_weighting is not None:
                    m_r *= self.edge_weighting(source=sources_r, target=targets_r).unsqueeze(dim=-1)

                # Aggregate messages in target
                new_x.index_add_(dim=0, index=targets_r, source=m_r)

            # Self-loop: the extra relation index num_relations holds the self-loop weights.
            self_w = self._get_relation_weights(i_layer=i, r=self.num_relations)
            if node_keep_mask is None:
                new_x += x @ self_w
            else:
                # Self-loop dropout: only kept nodes receive their own message.
                new_x[node_keep_mask] += x[node_keep_mask] @ self_w

            # Apply bias, if requested
            if self.use_bias:
                bias = self.biases[i]
                new_x += bias

            # Apply batch normalization, if requested
            if self.use_batch_norm:
                batch_norm = self.batch_norms[i]
                new_x = batch_norm(new_x)

            # Apply non-linearity
            if self.activations is not None:
                activation = self.activations[i]
                new_x = activation(new_x)

            x = new_x

        # Buffer the full-graph result for reuse until the next parameter update.
        if batch is None and self.buffer_messages:
            self.enriched_embeddings = x

        return x
def _get_relation_weights(self, i_layer: int, r: int) -> torch.FloatTensor:
if self.decomposition == 'block':
# allocate weight
w = torch.zeros(self.embedding_dim, self.embedding_dim, device=self.device)
# Get blocks
this_layer_blocks = self.bases[i_layer]
# self.bases[i_layer].shape (num_relations, num_blocks, embedding_dim/num_blocks, embedding_dim/num_blocks)
# note: embedding_dim is guaranteed to be divisible by num_bases in the constructor
block_size = self.embedding_dim // self.num_bases
for b, start in enumerate(range(0, self.embedding_dim, block_size)):
stop = start + block_size
w[start:stop, start:stop] = this_layer_blocks[r, b, :, :]
elif self.decomposition == 'basis':
# The current basis weights, shape: (num_bases)
att = self.att[i_layer][r, :]
# the current bases, shape: (num_bases, embedding_dim, embedding_dim)
b = self.bases[i_layer]
# compute the current relation weights, shape: (embedding_dim, embedding_dim)
w = torch.sum(att[:, None, None] * b, dim=0)
else:
raise AssertionError(f'Unknown decomposition: {self.decomposition}')
return w
[docs] def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102
# Enrich embeddings
self.base_model.entity_embeddings.weight.data = self._enrich_embeddings(batch=None)
return self.base_model.score_hrt(hrt_batch=hrt_batch)