Representations
In PyKEEN, a Representation is used to map integer indices to numeric representations. A simple example is an Embedding, where the mapping is a simple lookup. However, more advanced representation modules are available, too.
This tutorial is intended to provide a comprehensive overview of possible components. Feel free to visit the pages of the individual representations for detailed technical information.
Base
The Representation class defines a common interface for all representation modules.
Each representation defines a max_id attribute.
We can pass any integer index \(i \in [0, \text{max\_id})\) to a representation module to get a numeric representation of a fixed shape, given by its shape attribute.
Note
To support efficient training and inference, all representations accept batches of indices of arbitrary shape, and return batches of corresponding numeric representations. The batch dimensions always precede the actual shape of the returned numerical representations.
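To make this contract concrete, here is a minimal pure-Python sketch, not PyKEEN's implementation: a hypothetical ToyRepresentation class stores one fixed-shape vector per index, rejects indices outside [0, max_id), and accepts nested lists of indices so that the batch dimensions precede the representation's own shape. The nested_shape helper is likewise hypothetical and only used to inspect the result.

```python
from typing import Any, List, Sequence, Tuple


class ToyRepresentation:
    """Hypothetical stand-in for a representation: maps an index to a fixed-shape vector."""

    def __init__(self, table: Sequence[Sequence[float]]) -> None:
        self.table = [list(row) for row in table]
        self.max_id = len(self.table)  # valid indices: 0, ..., max_id - 1
        self.shape: Tuple[int, ...] = (len(self.table[0]),)

    def __call__(self, indices: Any) -> Any:
        # A single integer index yields one vector of shape `self.shape`.
        if isinstance(indices, int):
            if not 0 <= indices < self.max_id:
                raise IndexError(f"index {indices} not in [0, {self.max_id})")
            return list(self.table[indices])
        # A (nested) list of indices is mapped element-wise, keeping the batch structure.
        return [self(i) for i in indices]


def nested_shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular, non-empty nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape: List[int] = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


rep = ToyRepresentation([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])
batch = [[0, 2], [1, 1]]  # batch shape (2, 2)
out = rep(batch)
print(nested_shape(out))  # (2, 2, 3): batch shape followed by rep.shape
```

In PyKEEN itself, indices and outputs are PyTorch tensors rather than nested lists; the sketch only mirrors the shape rule.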
Combinations & Adapters
PyKEEN provides a rich set of generic tools to combine and adapt representations to form new representations.
Transformed Representations
A TransformedRepresentation adds a (learnable) transformation to an existing representation. It can be particularly useful when we have some fixed features for entities or relations, e.g., from a pre-trained model, or encodings of other modalities like text or images, and want to learn a transformation on them to make them suitable for simple interaction functions like DistMultInteraction.
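The idea can be sketched in plain Python, assuming the simplest possible transformation, a linear map. The class ToyTransformedRepresentation and the helper matvec are hypothetical; in PyKEEN the transformation would typically be a trainable PyTorch module, and the plain weight matrix below only stands in for it.

```python
from typing import List

Matrix = List[List[float]]


def matvec(weight: Matrix, x: List[float]) -> List[float]:
    """Compute weight @ x for an (out_dim x in_dim) matrix and an in_dim vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


class ToyTransformedRepresentation:
    """Hypothetical sketch: frozen base features followed by a (would-be trainable) linear map."""

    def __init__(self, features: Matrix, weight: Matrix) -> None:
        self.features = features  # fixed, e.g. pre-computed text or image encodings
        self.weight = weight      # the part that would be learned during training

    def __call__(self, index: int) -> List[float]:
        return matvec(self.weight, self.features[index])


# Two entities with 3-dimensional fixed features, projected to 2 dimensions.
features = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]
weight = [[1.0, 1.0, 0.0], [0.0, 0.5, 1.0]]
rep = ToyTransformedRepresentation(features, weight)
print(rep(0), rep(1))  # [1.0, 2.0] [1.0, 1.5]
```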
Subset Representations
A SubsetRepresentation allows us to “hide” some indices.
This can be useful, e.g., if we want to share some representations between modules while others should be exclusive; for example, we may want to use inverse relations in a message passing phase, but no inverse relations when scoring triples.
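A minimal sketch of the inverse-relation use case, assuming (for illustration only) that the inverse relations are stored after the original ones and that the subset simply exposes the first max_id rows. ToySubsetRepresentation is a hypothetical name; the point is that both views share the same storage while the scoring view cannot see the inverses.

```python
from typing import List, Sequence


class ToySubsetRepresentation:
    """Hypothetical sketch: expose only the first `max_id` rows of a shared base table."""

    def __init__(self, base: Sequence[List[float]], max_id: int) -> None:
        if max_id > len(base):
            raise ValueError("a subset cannot be larger than its base")
        self.base = base      # shared storage; no copy is made
        self.max_id = max_id  # indices >= max_id are hidden

    def __call__(self, index: int) -> List[float]:
        if not 0 <= index < self.max_id:
            raise IndexError(f"index {index} is hidden by this subset")
        return self.base[index]


# Relation table used for message passing: 2 relations followed by their 2 inverses.
relations_with_inverses = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
# The scoring view only sees the 2 original relations, but shares their storage.
scoring_view = ToySubsetRepresentation(relations_with_inverses, max_id=2)
print(scoring_view.max_id, scoring_view(1))  # 2 [0.0, 1.0]
```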
Partitioned Representations
A PartitionRepresentation uses multiple base representations and chooses exactly one of them for each index based on a fixed mapping.
BackfillRepresentation implements a frequently used special case with two base representations: one is the “main” representation, and the other is used as a backup whenever the first one fails to provide a representation.
This is useful when we want to use features or pre-trained embeddings whenever possible, and learn new embeddings for any entities or relations for which we have no features.
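Both ideas fit in a short plain-Python sketch: a fixed assignment maps each index to a (base id, position) pair, and a backfill is just a partition with two bases, where only the indices without features receive new vectors. ToyPartitionRepresentation and toy_backfill are hypothetical names, and zero initialization stands in for whatever trainable initialization a real model would use.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


class ToyPartitionRepresentation:
    """Hypothetical sketch: a fixed mapping decides which base serves each index."""

    def __init__(self, bases: Sequence[Sequence[Vector]], assignment: Dict[int, Tuple[int, int]]) -> None:
        self.bases = bases
        self.assignment = assignment  # index -> (base id, position inside that base)

    def __call__(self, index: int) -> Vector:
        base_id, local_index = self.assignment[index]
        return self.bases[base_id][local_index]


def toy_backfill(features: Dict[int, Vector], max_id: int, dim: int) -> ToyPartitionRepresentation:
    """Hypothetical helper: use features where present, new vectors for the rest."""
    known = sorted(features)
    missing = [i for i in range(max_id) if i not in features]
    main = [features[i] for i in known]      # fixed, e.g. pre-trained features
    backup = [[0.0] * dim for _ in missing]  # would be trainable; zero-initialized here
    assignment = {i: (0, pos) for pos, i in enumerate(known)}
    assignment.update({i: (1, pos) for pos, i in enumerate(missing)})
    return ToyPartitionRepresentation([main, backup], assignment)


rep = toy_backfill({0: [0.5, 0.5], 2: [1.0, -1.0]}, max_id=3, dim=2)
print([rep(i) for i in range(3)])  # [[0.5, 0.5], [0.0, 0.0], [1.0, -1.0]]
```

Note that the backup table only has one row here, because only index 1 lacks features; no parameters are spent on entities that already have features.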
Combined Representations
CombinedRepresentation can be used when we have multiple sources of representations and want to combine them into a single one.
Use cases include multi-modal models and NodePieceRepresentation.
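As a rough illustration, the following plain-Python sketch queries several base sources for the same index and merges the results with a pluggable combination function. ToyCombinedRepresentation is hypothetical, and concatenation is just one possible choice of combination.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def concat(parts: Sequence[Vector]) -> Vector:
    """One possible combination: concatenate the parts."""
    return [x for part in parts for x in part]


class ToyCombinedRepresentation:
    """Hypothetical sketch: query several base sources, then merge their outputs."""

    def __init__(
        self,
        bases: Sequence[Callable[[int], Vector]],
        combine: Callable[[Sequence[Vector]], Vector] = concat,
    ) -> None:
        self.bases = bases
        self.combine = combine

    def __call__(self, index: int) -> Vector:
        return self.combine([base(index) for base in self.bases])


text_features = [[0.1, 0.2], [0.3, 0.4]]  # e.g., from a text encoder
image_features = [[1.0], [2.0]]           # e.g., from an image encoder
rep = ToyCombinedRepresentation([lambda i: text_features[i], lambda i: image_features[i]])
print(rep(1))  # [0.3, 0.4, 2.0]
```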
Embedding
An Embedding is the simplest representation, where an index is mapped to a numerical representation by a simple lookup in a table.
Despite its simplicity, almost all publications on transductive link prediction rely on embeddings to represent entities or relations.
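A minimal sketch of an embedding as a lookup table, using only the standard library. ToyEmbedding is a hypothetical name, and the Gaussian initialization is an arbitrary choice for the illustration. The parameter count makes visible that every index owns its own independent row, i.e., max_id × dim parameters.

```python
import random
from typing import List


class ToyEmbedding:
    """Hypothetical sketch of an embedding: one independent row per index in a table."""

    def __init__(self, max_id: int, dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        # Every entry is an independent (would-be trainable) parameter.
        self.table: List[List[float]] = [
            [rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(max_id)
        ]

    def __call__(self, index: int) -> List[float]:
        return self.table[index]  # the "mapping" is just a row lookup

    def num_parameters(self) -> int:
        return sum(len(row) for row in self.table)


emb = ToyEmbedding(max_id=5, dim=4)
print(len(emb(3)), emb.num_parameters())  # 4 20
```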
Decomposition
Since knowledge graphs can contain a large number of entities, having independent trainable embeddings for each of them can lead to an excessive number of trainable parameters. Therefore, methods have been developed that do not learn independent representations, but rather have a set of base representations and create individual representations by combining them.
Low-Rank Factorization
A simple method to reduce the number of parameters is to use a low-rank decomposition of the embedding matrix, as implemented in LowRankRepresentation. Here, each representation is a linear combination of shared base representations.
Typically, the number of bases is chosen to be smaller than the dimension of each base representation.
Low-rank factorization can also be seen as a special case of CombinedRepresentation with a restricted (but very efficient) combination operation.
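The following plain-Python sketch shows the linear-combination view and the resulting parameter savings. ToyLowRankRepresentation and parameter_counts are hypothetical names; the sizes in the last line (100,000 entities, 64 bases, dimension 256) are arbitrary illustrative values, not recommendations.

```python
from typing import List, Tuple

Matrix = List[List[float]]


class ToyLowRankRepresentation:
    """Hypothetical sketch: each row is a weighted sum of a few shared base vectors."""

    def __init__(self, weights: Matrix, bases: Matrix) -> None:
        self.weights = weights  # shape (max_id, num_bases): per-index mixing coefficients
        self.bases = bases      # shape (num_bases, dim): shared across all indices

    def __call__(self, index: int) -> List[float]:
        coefficients = self.weights[index]
        dim = len(self.bases[0])
        return [sum(c * base[d] for c, base in zip(coefficients, self.bases)) for d in range(dim)]


def parameter_counts(max_id: int, num_bases: int, dim: int) -> Tuple[int, int]:
    """(full embedding matrix, low-rank factorization) parameter counts."""
    return max_id * dim, max_id * num_bases + num_bases * dim


rep = ToyLowRankRepresentation(
    weights=[[1.0, 0.0], [0.5, 0.5], [2.0, -1.0]],
    bases=[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
)
print(rep(1), rep(2))  # [0.5, 0.5, 0.5] [2.0, -1.0, -1.0]
print(parameter_counts(max_id=100_000, num_bases=64, dim=256))  # (25600000, 6416384)
```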
Tensor Train Factorization
TensorTrainRepresentation uses a tensor factorization method, which can also be interpreted as a hierarchical decomposition.
The tensor train decomposition is also known as matrix product states.
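To give an intuition, here is a pure-Python sketch of the smallest non-trivial case: an embedding matrix with n1·n2 rows and d1·d2 columns stored as two cores that share a rank dimension. The row index is split into (i1, i2), and each entry is obtained by contracting the rank dimension. The names tt_row and tt_parameter_counts are hypothetical; PyKEEN's TensorTrainRepresentation may use more cores and a different parameter layout.

```python
from typing import List, Tuple

# core1[i1][a][k]: shape (n1, d1, rank); core2[i2][k][b]: shape (n2, rank, d2).
Core = List[List[List[float]]]


def tt_row(index: int, core1: Core, core2: Core) -> List[float]:
    """Row `index` of an (n1*n2) x (d1*d2) matrix stored as a 2-core tensor train."""
    n2 = len(core2)
    i1, i2 = divmod(index, n2)  # mixed-radix split of the row index
    d1, rank, d2 = len(core1[0]), len(core2[0]), len(core2[0][0])
    return [
        # contract the shared rank dimension; the column index is a * d2 + b
        sum(core1[i1][a][k] * core2[i2][k][b] for k in range(rank))
        for a in range(d1)
        for b in range(d2)
    ]


def tt_parameter_counts(n1: int, n2: int, d1: int, d2: int, rank: int) -> Tuple[int, int]:
    """(dense matrix, tensor train) parameter counts."""
    return n1 * n2 * d1 * d2, n1 * d1 * rank + n2 * rank * d2


core1 = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 0.0]]]  # n1=2, d1=2, rank=2
core2 = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 0.0]]]  # n2=2, rank=2, d2=2
print(tt_row(0, core1, core2))  # [1.0, 2.0, 3.0, 4.0]
print(tt_row(3, core1, core2))  # [0.0, 2.0, 0.0, 0.0]
print(tt_parameter_counts(n1=1000, n2=1000, d1=16, d2=16, rank=8))  # (256000000, 256000)
```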
NodePiece
Another example is NodePiece, which takes inspiration from tokenization as we encounter it in, e.g., NLP, and represents each entity as a set of tokens.
The basic implementation can be found in TokenizationRepresentation, where each index is represented by a sequence of tokens, and the tokens have their own representation.
NodePieceRepresentation builds upon it and uses one or more TokenizationRepresentation instances, which are then combined into a single representation.
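A toy version of this idea in plain Python: only the token vectors are parameters, each entity is described by a short token sequence, and the token vectors are pooled into one entity vector. The vocabulary, the tokenization, and the choice of mean pooling are all assumptions made for this sketch; in particular, NodePiece's actual aggregation may be a learned function rather than a mean.

```python
from typing import Dict, List

Vector = List[float]

# Hypothetical token vocabulary: only these vectors would be trainable parameters.
token_embeddings: Dict[str, Vector] = {
    "anchor:paris": [1.0, 0.0],
    "anchor:berlin": [0.0, 1.0],
    "rel:capital_of": [1.0, 1.0],
    "rel:borders": [-1.0, 1.0],
}

# Hypothetical tokenization: each entity is described by a short sequence of tokens.
entity_tokens: Dict[str, List[str]] = {
    "france": ["anchor:paris", "rel:capital_of"],
    "germany": ["anchor:berlin", "rel:borders"],
}


def tokenized(entity: str) -> List[Vector]:
    """Token vectors for one entity, shape (num_tokens, dim)."""
    return [token_embeddings[token] for token in entity_tokens[entity]]


def mean_pool(vectors: List[Vector]) -> Vector:
    """One simple way to combine token vectors into a single entity vector."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


print(mean_pool(tokenized("france")))   # [1.0, 0.5]
print(mean_pool(tokenized("germany")))  # [-0.5, 1.0]
```

Because entities only reference shared tokens, the number of parameters depends on the vocabulary size rather than on the number of entities.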
Message Passing
Message passing representation modules enrich the representations of entities by aggregating the information from their graph neighborhood.
RGCN
The RGCNRepresentation uses RGCNLayer to pass messages between entities.
These layers aggregate representations of neighboring entities, which are first transformed by a relation-specific linear transformation.
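The following pure-Python sketch shows one simplified update in this spirit: each neighbor vector is transformed by the weight matrix of the connecting relation, incoming messages are averaged, and a self-loop term is added. toy_rgcn_layer is hypothetical; it omits the per-relation normalization, the nonlinearity, and the messages along inverse edges that a full R-GCN layer would typically include.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(weight: Matrix, x: Vector) -> Vector:
    """weight @ x for an (out_dim x in_dim) matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def toy_rgcn_layer(
    x: List[Vector],
    triples: List[Tuple[int, int, int]],
    relation_weights: Dict[int, Matrix],
    self_weight: Matrix,
) -> List[Vector]:
    """Hypothetical simplified update: self-loop plus mean of relation-transformed neighbors."""
    out = [matvec(self_weight, xi) for xi in x]
    incoming: List[List[Vector]] = [[] for _ in x]
    for h, r, t in triples:
        # The neighbor is first transformed by the weight matrix of relation r.
        incoming[t].append(matvec(relation_weights[r], x[h]))
    for t, messages in enumerate(incoming):
        if messages:
            mean = [sum(values) / len(messages) for values in zip(*messages)]
            out[t] = [a + b for a, b in zip(out[t], mean)]
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
triples = [(0, 0, 2), (1, 1, 2)]
relation_weights = {0: [[2.0, 0.0], [0.0, 2.0]], 1: [[0.0, 1.0], [1.0, 0.0]]}
identity = [[1.0, 0.0], [0.0, 1.0]]
print(toy_rgcn_layer(x, triples, relation_weights, identity))
# [[1.0, 0.0], [0.0, 1.0], [2.5, 1.0]]
```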
CompGCN
The SingleCompGCNRepresentation enriches representations using CompGCNLayer, which instead uses a more flexible composition of entity and relation representations along each edge.
As a technical detail, since each CompGCNLayer transforms both entity and relation representations, we must first construct a CombinedCompGCNRepresentations and then split its output into separate SingleCompGCNRepresentation instances for entities and relations, respectively.
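A pure-Python sketch of the composition idea, under strong simplifications: the message along an edge is a composition of the head entity vector and the relation vector (subtraction or element-wise product are two common choices), and the step returns new entity and relation vectors together, which is why the output has to be split afterwards. toy_compgcn_step is hypothetical; the scalar relation_scale merely stands in for a learned linear map, and the weight matrices and nonlinearities of a real CompGCN layer are omitted.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Compose = Callable[[Vector, Vector], Vector]


def subtract(e: Vector, r: Vector) -> Vector:
    """TransE-style composition."""
    return [a - b for a, b in zip(e, r)]


def multiply(e: Vector, r: Vector) -> Vector:
    """DistMult-style composition."""
    return [a * b for a, b in zip(e, r)]


def toy_compgcn_step(
    entities: List[Vector],
    relations: List[Vector],
    triples: List[Tuple[int, int, int]],
    compose: Compose,
    relation_scale: float = 2.0,
) -> Tuple[List[Vector], List[Vector]]:
    """Hypothetical simplified step that updates entities AND relations."""
    new_entities = [list(e) for e in entities]  # self-loop: start from own vector
    for h, r, t in triples:
        message = compose(entities[h], relations[r])  # composition along the edge
        new_entities[t] = [a + b for a, b in zip(new_entities[t], message)]
    # Stand-in for a learned linear map on relation vectors.
    new_relations = [[relation_scale * v for v in rel] for rel in relations]
    return new_entities, new_relations


entities = [[1.0, 2.0], [3.0, 4.0]]
relations = [[0.5, 0.5]]
triples = [(0, 0, 1)]
# One joint output, split afterwards into entity and relation parts.
ent_sub, rel_out = toy_compgcn_step(entities, relations, triples, subtract)
ent_mul, _ = toy_compgcn_step(entities, relations, triples, multiply)
print(ent_sub, rel_out)  # [[1.0, 2.0], [3.5, 5.5]] [[1.0, 1.0]]
print(ent_mul)           # [[1.0, 2.0], [3.5, 5.0]]
```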
PyTorch Geometric
Another way to utilize message passing is via the modules provided in pykeen.nn.pyg, which allow us to use the message passing layers from PyTorch Geometric to enrich base representations.
We include the following templates to easily create custom transformations:
MessagePassingRepresentation: Base class.
SimpleMessagePassingRepresentation: For message passing ignoring relation type information.
TypedMessagePassingRepresentation: For message passing using categorical relation type information.
FeaturizedMessagePassingRepresentation: For message passing that uses relation representations.
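PyTorch Geometric layers commonly receive the graph as an edge_index of shape (2, num_edges); layers such as RGCNConv additionally take a categorical edge_type, and others, such as GATConv, can take per-edge features via edge_attr. The plain-Python sketch below shows how triples could be turned into these three kinds of inputs, which is our reading of how the three templates differ, not their actual code. Lists stand in for tensors, and relation_vectors is a hypothetical relation representation.

```python
from typing import List, Tuple

triples: List[Tuple[int, int, int]] = [(0, 0, 1), (1, 1, 2), (2, 0, 0)]  # (head, relation, tail)
relation_vectors = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical relation representations

# "Simple": graph structure only, as a 2 x num_edges array of (source, target) indices.
edge_index = [[h for h, _, _ in triples], [t for _, _, t in triples]]

# "Typed": additionally one categorical relation id per edge.
edge_type = [r for _, r, _ in triples]

# "Featurized": one relation vector per edge, looked up from the relation representations.
edge_attr = [relation_vectors[r] for r in edge_type]

print(edge_index)  # [[0, 1, 2], [1, 2, 0]]
print(edge_type)   # [0, 1, 0]
print(edge_attr)   # [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
```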
Text-based
Text-based representations use the entities’ (or relations’) labels to derive representations. To this end, TextRepresentation uses a (pre-trained) transformer model from the transformers library to encode the labels. Since the transformer models have been trained on huge corpora of text, their text encodings often contain semantic information, i.e., labels with similar semantic meaning get similar representations. While we can also benefit from these strong features by just initializing an Embedding with the vectors, e.g., using LabelBasedInitializer, the TextRepresentation includes the transformer model as part of the KGE model, and thus allows fine-tuning the language model for the KGE task. This is beneficial, e.g., since it provides a simple way to obtain an inductive model, which can make predictions for entities not seen during training.
import torch
from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn import TextRepresentation
from pykeen.pipeline import pipeline

dataset = get_dataset(dataset="nations")
entity_representations = TextRepresentation.from_dataset(dataset=dataset, encoder="transformer")
result = pipeline(
    dataset=dataset,
    model=ERModel,
    model_kwargs=dict(
        interaction="ermlpe",
        interaction_kwargs=dict(
            embedding_dim=entity_representations.shape[0],
        ),
        entity_representations=entity_representations,
        relation_representations_kwargs=dict(
            shape=entity_representations.shape,
        ),
    ),
    training_kwargs=dict(num_epochs=1),
)
model = result.model
We can use the label-encoder part to generate representations for unknown entities with labels. For instance, “uk” is an entity in Nations, but we can also put in “united kingdom” and get a roughly equivalent vector representation:
entity_representation = model.entity_representations[0]
label_encoder = entity_representation.encoder
uk, united_kingdom = label_encoder(labels=["uk", "united kingdom"])
Thus, if we put the resulting representations into the interaction function, we get similar scores:
# true triple from train: ['brazil', 'exports3', 'uk']
relation_representation = model.relation_representations[0]
h_repr = entity_representation.get_in_more_canonical_shape(
    dim="h", indices=torch.as_tensor(dataset.entity_to_id["brazil"]).view(1)
)
r_repr = relation_representation.get_in_more_canonical_shape(
    dim="r", indices=torch.as_tensor(dataset.relation_to_id["exports3"]).view(1)
)
scores = model.interaction(h=h_repr, r=r_repr, t=torch.stack([uk, united_kingdom]))
print(scores)
As a downside, this will usually substantially increase the computational cost of computing triple scores.
Wikidata
Since quite a few benchmark datasets for link prediction on knowledge graphs use Wikidata as a source, e.g., CoDExSmall or WD50KT, we added a convenience class WikidataTextRepresentation that looks up labels based on Wikidata QIDs (e.g., Q42 for Douglas Adams).
Biomedical Entities
If your dataset is labeled with compact uniform resource identifiers (CURIEs) for biomedical entities like chemicals, proteins, diseases, and pathways, then the BiomedicalCURIERepresentation can make use of pyobo to look up names via the pyobo.get_name() function, and then encode them using the text encoder.
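The lookup step can be sketched in plain Python: split each CURIE into prefix and identifier, resolve it to a name, and pass the names on to the text encoder. Everything here is hypothetical: split_curie is a simplified parser without normalization, TOY_NAMES is a stand-in for a real name lookup such as pyobo.get_name(), and falling back to the raw CURIE when no name is found is a choice of this sketch, not necessarily what BiomedicalCURIERepresentation does.

```python
from typing import Callable, Dict, List, Optional, Tuple


def split_curie(curie: str) -> Tuple[str, str]:
    """Split "prefix:identifier" at the first colon (simplified; no normalization)."""
    prefix, sep, identifier = curie.partition(":")
    if not sep or not prefix or not identifier:
        raise ValueError(f"not a CURIE: {curie!r}")
    return prefix, identifier


# Hypothetical stand-in for a name lookup service such as pyobo.get_name().
TOY_NAMES: Dict[Tuple[str, str], str] = {("chebi", "15377"): "water"}


def curies_to_labels(curies: List[str], lookup: Callable[[str, str], Optional[str]]) -> List[str]:
    """Resolve each CURIE to a name; fall back to the CURIE itself (a choice of this sketch)."""
    labels = []
    for curie in curies:
        name = lookup(*split_curie(curie))
        labels.append(name if name is not None else curie)
    return labels


labels = curies_to_labels(["chebi:15377", "example:123"], lambda p, i: TOY_NAMES.get((p, i)))
print(labels)  # ['water', 'example:123']
# These labels would then be passed on to the text encoder.
```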
Unfortunately, at the time this representation was added, none of the biomedical knowledge graphs in PyKEEN used CURIEs to reference biomedical entities. In the future, we hope this will change.
To learn more about CURIEs, please take a look at the Bioregistry and this blog post on CURIEs.
Visual
Sometimes, we also have visual information about entities, e.g., in the form of images.
For these cases, there is VisualRepresentation, which uses an image encoder backbone to obtain representations.
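To illustrate the pipeline from image to vector, the plain-Python sketch below uses global average pooling over all pixels as a stand-in "backbone". This is only an assumption for the illustration: a real backbone would be a pre-trained vision model (many convolutional networks such as ResNets end in a global average pooling layer over learned feature maps, not raw pixels). toy_backbone and entity_images are hypothetical names.

```python
from typing import Dict, List

Vector = List[float]
Image = List[List[Vector]]  # height x width x channels


def toy_backbone(image: Image) -> Vector:
    """Stand-in encoder: average each channel over all pixels (global average pooling)."""
    pixels = [pixel for row in image for pixel in row]
    return [sum(channel) / len(pixels) for channel in zip(*pixels)]


# Hypothetical entity -> image mapping (2x2 RGB images).
entity_images: Dict[int, Image] = {
    0: [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]],
    1: [[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
}
vectors = {entity: toy_backbone(img) for entity, img in entity_images.items()}
print(vectors[0], vectors[1])  # [0.5, 0.5, 0.5] [1.0, 0.0, 0.0]
```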
Wikidata
As for textual representations, we provide a convenience class WikidataVisualRepresentation for Wikidata-based datasets that looks up images based on Wikidata QIDs.