
# -*- coding: utf-8 -*-

"""Various decompositions for R-GCN."""

import logging
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union

import torch
from class_resolver import Resolver
from torch import nn
from torch.nn import functional

__all__ = [
    "Decomposition",
    "BasesDecomposition",
    "BlockDecomposition",
    "decomposition_resolver",
]


def _reduce_relation_specific(
    relation: int,
    source: torch.LongTensor,
    target: torch.LongTensor,
    edge_type: torch.LongTensor,
    edge_weights: Optional[torch.FloatTensor],
) -> Union[Tuple[torch.LongTensor, torch.LongTensor, Optional[torch.FloatTensor]], Tuple[None, None, None]]:
    """Reduce edge information to one relation.

    :param relation:
        The relation ID.
    :param source: shape: (num_edges,)
        The source node IDs.
    :param target: shape: (num_edges,)
        The target node IDs.
    :param edge_type: shape: (num_edges,)
        The edge types.
    :param edge_weights: shape: (num_edges,)
        The edge weights.

    :return:
        The source, target, weights for edges related to the desired relation type.
    """
    # mask, shape: (num_edges,)
    edge_mask = edge_type == relation
    if not edge_mask.any():
        return None, None, None

    source_r = source[edge_mask]
    target_r = target[edge_mask]
    if edge_weights is not None:
        edge_weights = edge_weights[edge_mask]

    # bi-directional message passing: also add the inverse edges (target -> source)
    source_r, target_r = torch.cat([source_r, target_r]), torch.cat([target_r, source_r])
    if edge_weights is not None:
        edge_weights = torch.cat([edge_weights, edge_weights])

    return source_r, target_r, edge_weights
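

# A minimal, self-contained sketch (illustrative only; the function name is hypothetical and not part of
# pykeen): reducing a toy graph with two relations to relation 0 keeps only the edges of that relation and
# duplicates them in both directions.
def _demo_reduce_relation_specific():
    source = torch.as_tensor([0, 1, 2])
    target = torch.as_tensor([1, 2, 0])
    edge_type = torch.as_tensor([0, 0, 1])
    source_r, target_r, weights_r = _reduce_relation_specific(
        relation=0,
        source=source,
        target=target,
        edge_type=edge_type,
        edge_weights=None,
    )
    # edges (0 -> 1) and (1 -> 2), plus their inverses (1 -> 0) and (2 -> 1)
    assert source_r.tolist() == [0, 1, 1, 2]
    assert target_r.tolist() == [1, 2, 0, 1]
    assert weights_r is None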


class Decomposition(nn.Module, ABC):
    """Base module for relation-specific message passing."""

    def __init__(
        self,
        input_dim: int,
        num_relations: int,
        output_dim: Optional[int] = None,
    ):
        """Initialize the layer.

        :param input_dim: >0
            The input dimension.
        :param num_relations: >0
            The number of relations.
        :param output_dim: >0
            The output dimension. If None is given, defaults to input_dim.
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_relations = num_relations
        if output_dim is None:
            output_dim = input_dim
        self.output_dim = output_dim

    @abstractmethod
    def forward(
        self,
        x: torch.FloatTensor,
        node_keep_mask: Optional[torch.BoolTensor],
        source: torch.LongTensor,
        target: torch.LongTensor,
        edge_type: torch.LongTensor,
        edge_weights: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Relation-specific message passing from source to target.

        :param x: shape: (num_nodes, input_dim)
            The node representations.
        :param node_keep_mask: shape: (num_nodes,)
            The node-keep mask for self-loop dropout.
        :param source: shape: (num_edges,)
            The source indices.
        :param target: shape: (num_edges,)
            The target indices.
        :param edge_type: shape: (num_edges,)
            The edge types.
        :param edge_weights: shape: (num_edges,)
            Precomputed edge weights.

        :return: shape: (num_nodes, output_dim)
            The enriched node embeddings.
        """
        raise NotImplementedError

    @abstractmethod
    def reset_parameters(self):
        """Reset the parameters of this layer."""
        raise NotImplementedError


class BasesDecomposition(Decomposition):
    """Represent relation-specific weight matrices as linear combinations of base transformation matrices."""

    def __init__(
        self,
        input_dim: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        output_dim: Optional[int] = None,
        memory_intense: bool = False,
    ):
        """Initialize the layer.

        :param input_dim: >0
            The input dimension.
        :param num_relations: >0
            The number of relations.
        :param num_bases: >0
            The number of bases to use.
        :param output_dim: >0
            The output dimension. If None is given, defaults to input_dim.
        :param memory_intense:
            Enable the memory-intense forward pass, which may be faster, in particular if the number of
            different relations is small.
        """
        super().__init__(
            input_dim=input_dim,
            num_relations=num_relations,
            output_dim=output_dim,
        )

        # Heuristic for default value
        if num_bases is None:
            logging.info("No num_bases was provided. Falling back to 2.")
            num_bases = 2

        if num_bases > num_relations:
            raise ValueError("The number of bases should not exceed the number of relations.")

        # weights
        self.bases = nn.Parameter(
            torch.empty(
                num_bases,
                self.input_dim,
                self.output_dim,
            ),
            requires_grad=True,
        )
        self.relation_base_weights = nn.Parameter(
            torch.empty(
                num_relations + 1,
                num_bases,
            ),
            requires_grad=True,
        )
        self.memory_intense = memory_intense

    def reset_parameters(self):  # noqa: D102
        nn.init.xavier_normal_(self.bases)
        # Random convex combination of bases for initialization (guarantees that the initial weight
        # matrices are initialized properly)
        # We have one additional relation for self-loops
        nn.init.uniform_(self.relation_base_weights)
        functional.normalize(self.relation_base_weights.data, p=1, dim=1, out=self.relation_base_weights.data)

    def _get_weight(self, relation_id: int) -> torch.FloatTensor:
        """Construct the weight matrix for a specific relation ID.

        :param relation_id:
            The relation ID.

        :return:
            A 2-D matrix.
        """
        return torch.einsum("bij,b->ij", self.bases, self.relation_base_weights[relation_id])

    def _forward_memory_intense(
        self,
        x: torch.FloatTensor,
        source: torch.LongTensor,
        target: torch.LongTensor,
        edge_type: torch.LongTensor,
        out: torch.FloatTensor,
        edge_weights: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # other relations
        m = torch.einsum(
            "mi,mb,bij->mj",
            x.index_select(dim=0, index=source),
            self.relation_base_weights.index_select(dim=0, index=edge_type),
            self.bases,
        )
        if edge_weights is not None:
            m = m * edge_weights.unsqueeze(dim=-1)
        return out.index_add(dim=0, index=target, source=m)

    def _forward_memory_light(
        self,
        x: torch.FloatTensor,
        source: torch.LongTensor,
        target: torch.LongTensor,
        edge_type: torch.LongTensor,
        out: torch.FloatTensor,
        edge_weights: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # other relations
        for r in range(self.num_relations):
            # Select source and target indices as well as edge weights for the
            # currently considered relation
            source_r, target_r, weights_r = _reduce_relation_specific(
                relation=r,
                source=source,
                target=target,
                edge_type=edge_type,
                edge_weights=edge_weights,
            )

            # skip relations without edges
            if source_r is None:
                continue

            # compute message, shape: (num_edges_of_type, output_dim)
            w = self._get_weight(relation_id=r)
            # Since one node ID may appear multiple times as a source ID, we can save some computation by
            # first reducing to the unique source IDs, computing the transformed representations, and
            # afterwards selecting these representations for the individual edges.
            uniq_source_r, inv_source_r = source_r.unique(return_inverse=True)
            # select unique source node representations
            m = x[uniq_source_r]
            # transform representations by relation-specific weight
            m = m @ w
            # select the uniquely transformed representations for each edge
            m = m.index_select(dim=0, index=inv_source_r)

            # optional message weighting
            if weights_r is not None:
                m = m * weights_r.unsqueeze(dim=-1)

            # message aggregation
            out = out.index_add(dim=0, index=target_r, source=m)

        return out

    def forward(
        self,
        x: torch.FloatTensor,
        node_keep_mask: Optional[torch.BoolTensor],
        source: torch.LongTensor,
        target: torch.LongTensor,
        edge_type: torch.LongTensor,
        edge_weights: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # self-loops first
        # the last relation ID refers to the self-loop
        w = self._get_weight(relation_id=self.num_relations)
        if node_keep_mask is not None:
            assert node_keep_mask.shape == x.shape[:1]
            out = torch.empty_like(x)
            out[node_keep_mask] = x[node_keep_mask] @ w
            out[~node_keep_mask] = 0.0
        else:
            out = x @ w

        if self.memory_intense:
            _forward = self._forward_memory_intense
        else:
            _forward = self._forward_memory_light
        return _forward(
            x=x,
            source=source,
            target=target,
            edge_type=edge_type,
            out=out,
            edge_weights=edge_weights,
        )
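

# A minimal, self-contained sketch (illustrative only; the function name is hypothetical and not part of
# pykeen): it checks that the relation-specific weight produced by ``_get_weight`` is indeed the linear
# combination of the shared bases, and runs one forward pass.
def _demo_bases_decomposition():
    torch.manual_seed(42)
    layer = BasesDecomposition(input_dim=4, num_relations=3, num_bases=2)
    # parameters are allocated with torch.empty, so they have to be initialized explicitly
    layer.reset_parameters()

    # weight of relation 0 == sum over bases b of coefficient[0, b] * basis matrix b
    w0 = layer._get_weight(relation_id=0)
    expected = (layer.relation_base_weights[0, :, None, None] * layer.bases).sum(dim=0)
    assert torch.allclose(w0, expected, atol=1.0e-6)

    # a forward pass enriches all node representations, shape: (num_nodes, output_dim)
    out = layer(
        x=torch.rand(5, 4),
        node_keep_mask=None,
        source=torch.as_tensor([0, 1, 2, 3]),
        target=torch.as_tensor([1, 2, 3, 4]),
        edge_type=torch.as_tensor([0, 1, 2, 0]),
    )
    assert out.shape == (5, 4)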


class BlockDecomposition(Decomposition):
    """Represent relation-specific weight matrices via block-diagonal matrices."""

    def __init__(
        self,
        input_dim: int,
        num_relations: int,
        num_blocks: Optional[int] = None,
        output_dim: Optional[int] = None,
    ):
        """Initialize the layer.

        :param input_dim: >0
            The input dimension.
        :param num_relations: >0
            The number of relations.
        :param num_blocks: >0
            The number of blocks to use. Has to be a divisor of input_dim.
        :param output_dim: >0
            The output dimension. If None is given, defaults to input_dim.
        """
        super().__init__(
            input_dim=input_dim,
            num_relations=num_relations,
            output_dim=output_dim,
        )

        if num_blocks is None:
            logging.info("Using a heuristic to determine the number of blocks.")
            num_blocks = min(i for i in range(2, input_dim + 1) if input_dim % i == 0)

        block_size, remainder = divmod(input_dim, num_blocks)
        if remainder != 0:
            raise NotImplementedError(
                "With block decomposition, the embedding dimension has to be divisible by the number of"
                f" blocks, but {input_dim} % {num_blocks} != 0.",
            )

        self.blocks = nn.Parameter(
            data=torch.empty(
                num_relations + 1,
                num_blocks,
                block_size,
                block_size,
            ),
            requires_grad=True,
        )
        self.num_blocks = num_blocks
        self.block_size = block_size

    def reset_parameters(self):  # noqa: D102
        block_size = self.blocks.shape[-1]
        # Xavier Glorot initialization of each block
        std = torch.sqrt(torch.as_tensor(2.0)) / (2 * block_size)
        nn.init.normal_(self.blocks, std=std)

    def forward(
        self,
        x: torch.FloatTensor,
        node_keep_mask: Optional[torch.BoolTensor],
        source: torch.LongTensor,
        target: torch.LongTensor,
        edge_type: torch.LongTensor,
        edge_weights: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # view as blocks
        x = x.view(-1, self.num_blocks, self.block_size)

        # self-loops first
        out = torch.zeros_like(x)
        w = self.blocks[-1]
        if node_keep_mask is not None:
            out[node_keep_mask] = torch.einsum("nbi,bij->nbj", x[node_keep_mask], w)
        else:
            out = torch.einsum("nbi,bij->nbj", x, w)

        # other relations
        for r in range(self.num_relations):
            source_r, target_r, weights_r = _reduce_relation_specific(
                relation=r,
                source=source,
                target=target,
                edge_type=edge_type,
                edge_weights=edge_weights,
            )

            # skip relations without edges
            if source_r is None:
                continue

            # compute message, shape: (num_edges_of_type, num_blocks, block_size)
            uniq_source_r, inv_source_r = source_r.unique(return_inverse=True)
            w_r = self.blocks[r]
            m = torch.einsum("nbi,bij->nbj", x[uniq_source_r], w_r).index_select(dim=0, index=inv_source_r)

            # optional message weighting
            if weights_r is not None:
                m = m * weights_r.unsqueeze(dim=1).unsqueeze(dim=2)

            # message aggregation
            out.index_add_(dim=0, index=target_r, source=m)

        return out.reshape(-1, self.output_dim)
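

# A minimal, self-contained sketch (illustrative only; the function name is hypothetical and not part of
# pykeen): for a graph without edges, the output is just the self-loop transformation, which equals a
# multiplication with the assembled block-diagonal matrix.
def _demo_block_decomposition():
    torch.manual_seed(42)
    layer = BlockDecomposition(input_dim=6, num_relations=2, num_blocks=3)
    layer.reset_parameters()

    x = torch.rand(4, 6)
    out = layer(
        x=x,
        node_keep_mask=None,
        source=torch.empty(0, dtype=torch.long),
        target=torch.empty(0, dtype=torch.long),
        edge_type=torch.empty(0, dtype=torch.long),
    )

    # assemble the full block-diagonal self-loop weight matrix, shape: (6, 6)
    w_self = torch.block_diag(*layer.blocks[-1])
    assert torch.allclose(out, x @ w_self, atol=1.0e-6)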


decomposition_resolver = Resolver.from_subclasses(base=Decomposition, default=BasesDecomposition)
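

# A minimal usage sketch (illustrative only; it assumes the standard lookup/make API of the
# class_resolver package, and the function name is hypothetical):
def _demo_decomposition_resolver():
    # resolve a decomposition class by its normalized name
    assert decomposition_resolver.lookup("block") is BlockDecomposition
    # instantiate a decomposition directly with constructor arguments
    layer = decomposition_resolver.make("bases", input_dim=4, num_relations=3)
    assert isinstance(layer, BasesDecomposition)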