# -*- coding: utf-8 -*-
"""Representation modules for NodePiece."""
import logging
from typing import Callable, List, NamedTuple, Optional, Sequence, Union
import torch
from class_resolver import HintOrType, OneOrManyHintOrType, OneOrManyOptionalKwargs, OptionalKwargs
from class_resolver.contrib.torch import aggregation_resolver
from .tokenization import Tokenizer, tokenizer_resolver
from ..representation import Representation
from ...triples import CoreTriplesFactory
from ...typing import MappedTriples, OneOrSequence
from ...utils import broadcast_upgrade_to_sequences
__all__ = [
"TokenizationRepresentation",
"HashDiversityInfo",
"NodePieceRepresentation",
]
logger = logging.getLogger(__name__)
class TokenizationRepresentation(Representation):
"""A module holding the result of tokenization."""
#: the size of the token vocabulary; the padding token has ID ``vocabulary_size - 1``
vocabulary_size: int
#: the token representations
vocabulary: Representation
#: the assigned tokens for each entity
assignment: torch.LongTensor
def __init__(
self,
assignment: torch.LongTensor,
token_representation: HintOrType[Representation] = None,
token_representation_kwargs: OptionalKwargs = None,
**kwargs,
) -> None:
"""
Initialize the tokenization.
:param assignment: shape: `(n, num_chosen_tokens)`
the token assignment.
:param token_representation: shape: `(num_total_tokens, *shape)`
the token representations
:param token_representation_kwargs:
additional keyword-based parameters
:param kwargs:
additional keyword-based parameters passed to super.__init__
:raises ValueError: if there's a mismatch between the representation size
and the vocabulary size
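A minimal usage sketch (assumptions: the ``"Embedding"`` hint resolves to :class:`pykeen.nn.Embedding`,
and the tensor values are purely illustrative; ``-1`` marks padding and is re-mapped to the last
vocabulary index):

.. code-block::

    import torch

    # three entities, each assigned two tokens; -1 denotes padding
    assignment = torch.as_tensor([[0, 1], [1, -1], [2, 3]])
    tokenization = TokenizationRepresentation(
        assignment=assignment,
        token_representation="Embedding",
        token_representation_kwargs=dict(shape=(8,)),
    )
    # token representations for all entities, shape: (3, 2, 8)
    x = tokenization()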
"""
# needs to be lazily imported to avoid cyclic imports
from .. import representation_resolver
# fill padding (nn.Embedding cannot deal with negative indices)
padding = assignment < 0
# sometimes, assignment.max() does not cover all relations (e.g., inductive inference graphs
# contain only a subset of the training relations) - hence, the padding index is the last index of the Representation
self.vocabulary_size = (
token_representation.max_id
if isinstance(token_representation, Representation)
else assignment.max().item() + 2 # exclusive (+1) and including padding (+1)
)
assignment[padding] = self.vocabulary_size - 1 # = assignment.max().item() + 1
max_id, num_chosen_tokens = assignment.shape
# resolve token representation
token_representation = representation_resolver.make(
token_representation,
token_representation_kwargs,
max_id=self.vocabulary_size,
)
super().__init__(max_id=max_id, shape=(num_chosen_tokens,) + token_representation.shape, **kwargs)
# input validation
if token_representation.max_id < self.vocabulary_size:
raise ValueError(
f"The token representations only contain {token_representation.max_id} representations,"
f"but there are {self.vocabulary_size} tokens in use.",
)
elif token_representation.max_id > self.vocabulary_size:
logger.warning(
f"Token representations do contain more representations ({token_representation.max_id}) "
f"than tokens are used ({self.vocabulary_size}).",
)
# register as buffer
self.register_buffer(name="assignment", tensor=assignment)
# assign sub-module
self.vocabulary = token_representation
@classmethod
def from_tokenizer(
cls,
tokenizer: Tokenizer,
num_tokens: int,
mapped_triples: MappedTriples,
num_entities: int,
num_relations: int,
token_representation: HintOrType[Representation] = None,
token_representation_kwargs: OptionalKwargs = None,
**kwargs,
) -> "TokenizationRepresentation":
"""
Create a tokenization from applying a tokenizer.
:param tokenizer:
the tokenizer instance.
:param num_tokens:
the number of tokens to select for each entity.
:param token_representation:
the token representation specification, or a pre-instantiated representation module
:param token_representation_kwargs:
additional keyword-based parameters
:param mapped_triples:
the ID-based triples
:param num_entities:
the number of entities
:param num_relations:
the number of relations
:param kwargs:
additional keyword-based parameters passed to TokenizationRepresentation.__init__
:return:
A tokenization representation by applying the tokenizer
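A sketch of typical usage (assuming the :class:`RelationTokenizer` and the small Nations dataset
shipped with PyKEEN; any other tokenizer or triples factory works analogously):

.. code-block::

    from pykeen.datasets import Nations
    from pykeen.nn.node_piece import RelationTokenizer

    training = Nations().training
    tokenization = TokenizationRepresentation.from_tokenizer(
        tokenizer=RelationTokenizer(),
        num_tokens=5,
        mapped_triples=training.mapped_triples,
        num_entities=training.num_entities,
        num_relations=training.num_relations,
    )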
"""
# apply tokenizer
vocabulary_size, assignment = tokenizer(
mapped_triples=mapped_triples,
num_tokens=num_tokens,
num_entities=num_entities,
num_relations=num_relations,
)
return TokenizationRepresentation(
assignment=assignment,
token_representation=token_representation,
token_representation_kwargs=token_representation_kwargs,
**kwargs,
)
# docstr-coverage: inherited
def _plain_forward(
self,
indices: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor: # noqa: D102
# get token IDs, shape: (*, num_chosen_tokens)
token_ids = self.assignment
if indices is not None:
token_ids = token_ids[indices]
# lookup token representations, shape: (*, num_chosen_tokens, *shape)
return self.vocabulary(token_ids)
class HashDiversityInfo(NamedTuple):
"""A ratio information object.
A pair ``(uniques_per_representation, uniques_total)``, where ``uniques_per_representation`` is a list with
the ratio of unique hashes for each token representation, and ``uniques_total``
is the ratio of unique hashes when all token representations are concatenated.
"""
#: A list with ratios per representation in their creation order,
#: e.g., ``[0.58, 0.82]`` for :class:`AnchorTokenizer` and :class:`RelationTokenizer`
uniques_per_representation: List[float]
#: A scalar ratio of unique rows when combining all representations into one matrix, e.g. 0.95
uniques_total: float
class NodePieceRepresentation(Representation):
r"""
Basic implementation of node piece decomposition [galkin2021]_.
.. math ::
x_e = agg(\{T[t] \mid t \in tokens(e) \})
where $T$ are token representations, $tokens$ selects a fixed number $k$ of tokens for each entity, and $agg$ is
an aggregation function which combines the individual token representations into a single entity representation.
.. note ::
This implementation currently only supports representation of entities by bag-of-relations.
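A minimal construction sketch (assuming the Nations dataset shipped with PyKEEN and relation-based
tokenization; the embedding shape and number of tokens are illustrative):

.. code-block::

    from pykeen.datasets import Nations

    entity_representation = NodePieceRepresentation(
        triples_factory=Nations().training,
        token_representations_kwargs=dict(shape=(64,)),
        tokenizers="RelationTokenizer",
        num_tokens=5,
    )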
"""
#: the token representations
token_representations: Sequence[TokenizationRepresentation]
def __init__(
self,
*,
triples_factory: CoreTriplesFactory,
token_representations: OneOrManyHintOrType[Representation] = None,
token_representations_kwargs: OneOrManyOptionalKwargs = None,
tokenizers: OneOrManyHintOrType[Tokenizer] = None,
tokenizers_kwargs: OneOrManyOptionalKwargs = None,
num_tokens: OneOrSequence[int] = 2,
aggregation: Union[None, str, Callable[[torch.FloatTensor, int], torch.FloatTensor]] = None,
max_id: Optional[int] = None,
shape: Optional[Sequence[int]] = None,
**kwargs,
):
"""
Initialize the representation.
:param triples_factory:
the triples factory
:param token_representations:
the token representation specification, or pre-instantiated representation module.
:param token_representations_kwargs:
additional keyword-based parameters
:param tokenizers:
the tokenizer to use, cf. `pykeen.nn.node_piece.tokenizer_resolver`.
:param tokenizers_kwargs:
additional keyword-based parameters passed to the tokenizer upon construction.
:param num_tokens:
the number of tokens for each entity.
:param aggregation:
aggregation of multiple token representations to a single entity representation. By default,
this uses :func:`torch.mean`. If a string is provided, the module assumes that this refers to a top-level
torch function, e.g., "mean" for :func:`torch.mean` or "sum" for :func:`torch.sum`. An aggregation can
also have trainable parameters, e.g., ``MLP(mean(MLP(tokens)))`` (cf. DeepSets from [zaheer2017]_). In
this case, the module has to be created outside of this component.
We could also have aggregations which result in differently shaped outputs, e.g., a concatenation of all
token embeddings resulting in shape ``(num_tokens * d,)``. In this case, `shape` must be provided.
The aggregation takes two arguments: the (batched) tensor of token representations, in shape
``(*, num_tokens, *dt)``, and the index along which to aggregate.
:param shape:
the shape of an individual representation. Only necessary if the aggregation results in a change of dimensions;
this will typically only be the case for an *ad hoc* aggregation function.
:param max_id:
only pass this to verify that it matches the number of entities in the triples factory
:param kwargs:
additional keyword-based parameters passed to super.__init__
:raises ValueError: if the shapes for any vocabulary entry
in all token representations are inconsistent
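As a sketch of an aggregation that changes the dimensionality, and therefore requires ``shape``,
consider concatenating all token embeddings (``concat_aggregation`` and all sizes below are
illustrative; ``training`` is assumed to be a pre-loaded :class:`CoreTriplesFactory`):

.. code-block::

    import torch

    def concat_aggregation(x: torch.FloatTensor, dim: int) -> torch.FloatTensor:
        # merge the token dimension into the feature dimension: (*, num_tokens, d) -> (*, num_tokens * d)
        return x.flatten(start_dim=dim)

    entity_representation = NodePieceRepresentation(
        triples_factory=training,
        token_representations_kwargs=dict(shape=(64,)),
        tokenizers="RelationTokenizer",
        num_tokens=5,
        aggregation=concat_aggregation,
        shape=(5 * 64,),
    )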
"""
if max_id:
assert max_id == triples_factory.num_entities
# normalize triples
mapped_triples = triples_factory.mapped_triples
if triples_factory.create_inverse_triples:
# inverse triples are created afterwards implicitly
mapped_triples = mapped_triples[mapped_triples[:, 1] < triples_factory.real_num_relations]
token_representations, token_representations_kwargs, num_tokens = broadcast_upgrade_to_sequences(
token_representations, token_representations_kwargs, num_tokens
)
# tokenize
token_representations = [
TokenizationRepresentation.from_tokenizer(
tokenizer=tokenizer_inst,
num_tokens=num_tokens_,
token_representation=token_representation,
token_representation_kwargs=token_representation_kwargs,
mapped_triples=mapped_triples,
num_entities=triples_factory.num_entities,
num_relations=triples_factory.real_num_relations,
)
for tokenizer_inst, token_representation, token_representation_kwargs, num_tokens_ in zip(
tokenizer_resolver.make_many(queries=tokenizers, kwargs=tokenizers_kwargs),
token_representations,
token_representations_kwargs,
num_tokens,
)
]
# determine shape
if shape is None:
shapes = {t.vocabulary.shape for t in token_representations}
if len(shapes) != 1:
raise ValueError(f"Inconsistent token shapes: {shapes}")
shape = list(shapes)[0]
# super init; has to happen *before* any parameter or buffer is assigned
super().__init__(max_id=triples_factory.num_entities, shape=shape, **kwargs)
# assign module
self.token_representations = torch.nn.ModuleList(token_representations)
# Assign default aggregation
self.aggregation = aggregation_resolver.lookup(aggregation)
self.aggregation_index = -(1 + len(shape))
# docstr-coverage: inherited
def _plain_forward(
self,
indices: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor: # noqa: D102
return self.aggregation(
torch.cat(
[tokenization(indices=indices) for tokenization in self.token_representations],
dim=self.aggregation_index,
),
self.aggregation_index,
)
def estimate_diversity(self) -> HashDiversityInfo:
"""
Estimate the diversity of the tokens via their hashes.
:return:
A ratio information tuple
Tokenization strategies might produce exactly the same hashes for
several nodes depending on the graph structure and tokenization
parameters. The same hashes will result in the same node representations
and, hence, might harm downstream performance.
This function comes in handy when you need to estimate the diversity
of the node hashes built under a certain tokenization strategy - ideally,
you'd want every node to have a unique hash.
The function computes how many node hashes are unique in each
representation and overall (if we concatenate all of them into a single matrix).
A value of 1.0 means that all nodes have unique hashes.
Example usage:
.. code-block::
from pykeen.models import NodePiece
model = NodePiece(
triples_factory=dataset.training,
tokenizers=["AnchorTokenizer", "RelationTokenizer"],
num_tokens=[20, 12],
embedding_dim=64,
interaction="rotate",
relation_constrainer="complex_normalize",
entity_initializer="xavier_uniform_",
)
print(model.entity_representations[0].estimate_diversity())
.. seealso:: https://github.com/pykeen/pykeen/pull/896
"""
# unique hashes per representation
uniques_per_representation = [
tokens.assignment.unique(dim=0).shape[0] / self.max_id for tokens in self.token_representations
]
# unique hashes if we concatenate all representations together
unnormalized_uniques_total = torch.unique(
torch.cat([tokens.assignment for tokens in self.token_representations], dim=-1), dim=0
).shape[0]
return HashDiversityInfo(
uniques_per_representation=uniques_per_representation,
uniques_total=unnormalized_uniques_total / self.max_id,
)