LowRankRepresentation
- class LowRankRepresentation(*, max_id: int | None = None, shape: Sequence[int] | int | None = None, num_bases: int | None = 3, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, weight: str | Representation | type[Representation] | None = None, weight_kwargs: Mapping[str, Any] | None = None, **kwargs)[source]
Bases: Representation
Low-rank embedding factorization.
This representation reduces the number of trainable parameters: instead of learning independent weights for each index, it shares a set of bases across all indices and learns only the weights of each index's linear combination.
\[E[i] = \sum_k B[i, k] \cdot W[k]\]

This representation implements the generalized form, where both \(B\) and \(W\) are arbitrary representations themselves.
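The parameter saving can be sketched with plain NumPy arrays standing in for the weight representation \(B\) and the base representation \(W\) (illustrative shapes, not the PyKEEN API):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
max_id, num_bases, dim = 10, 3, 32

B = rng.normal(size=(max_id, num_bases))  # per-index combination weights
W = rng.normal(size=(num_bases, dim))     # shared bases
E = B @ W                                 # E[i] = sum_k B[i, k] * W[k]

# trainable parameters: max_id * num_bases + num_bases * dim = 126,
# instead of max_id * dim = 320 for independent embeddings
assert E.shape == (max_id, dim)
```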
Example usage:
"""Use the (generalized) low-rank approximation to create a mixture model representation.""" import pandas from pykeen.datasets import get_dataset from pykeen.models import ERModel from pykeen.nn import LowRankRepresentation from pykeen.nn.text.cache import WikidataTextCache from pykeen.pipeline import pipeline from pykeen.typing import FloatTensor dataset = get_dataset(dataset="CoDExSmall", dataset_kwargs=dict(create_inverse_triples=True)) # set up relation representations as a mixture (~soft clustering) with 5 components embedding_dim = 32 num_components = 5 relation_representation = LowRankRepresentation( max_id=dataset.num_relations, shape=embedding_dim, num_bases=num_components, weight_kwargs=dict(normalizer="softmax"), ) # use DistMult interaction, and a simple embedding matrix for relations model = ERModel[FloatTensor, FloatTensor, FloatTensor]( triples_factory=dataset.training, interaction="distmult", entity_representations_kwargs=dict(shape=embedding_dim), relation_representations=relation_representation, ) result = pipeline(dataset=dataset, model=model, training_kwargs=dict(num_epochs=20)) # keys are Wikidata IDs, which are the "labels" in CoDEx, and values # are the concatenation of the Wikidata label + description wikidata_id_to_label = WikidataTextCache().get_texts_dict(dataset.relation_to_id) # use the mixture weights relation_weights = relation_representation.weight().detach().cpu().numpy() rows = [ (relation_index, wikidata_id, wikidata_id_to_label[wikidata_id], component, weight) for wikidata_id, relation_index in dataset.relation_to_id.items() for component, weight in enumerate(relation_weights[relation_index]) ] df = pandas.DataFrame(data=rows, columns=["relation_index", "wikidata-id", "text", "component_index", "weight"]) # For each component, look at the relations that are most assigned to it print( df.groupby(by="component_index").apply(lambda g: g.nlargest(3, columns="weight"), include_groups=False)[ ["wikidata-id", "text", "weight"] ] )
Initialize the representations.
- Parameters:
max_id (int) – The maximum ID (exclusive). Valid IDs range from 0 to max_id - 1. If None, a pre-instantiated weight representation needs to be provided.
shape (tuple[int, ...]) – The shape of an individual representation. If None, a pre-instantiated base representation has to be provided.
num_bases (int | None) – The number of bases. More bases increase expressivity, but also increase the number of trainable parameters. If None, a pre-instantiated base representation has to be provided.
weight (HintOrType[Representation]) – The weight representation, or a hint thereof.
weight_kwargs (OptionalKwargs) – Additional keyword-based arguments used to instantiate the weight representation.
base (HintOrType[Representation]) – The base representation, or a hint thereof.
base_kwargs (OptionalKwargs) – Additional keyword-based arguments used to instantiate the base representation.
kwargs – Additional keyword-based arguments passed to Representation.
- Raises:
MaxIDMismatchError – if max_id was given explicitly and does not match the max_id of the weight representation.
Note
The parameter pairs (base, base_kwargs) and (weight, weight_kwargs) are used for pykeen.nn.representation_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
Attributes Summary
num_bases – Return the number of bases.
Methods Summary
approximate(other[, num_bases]) – Construct a low-rank approximation of another representation.
Attributes Documentation
- num_bases
Return the number of bases.
Methods Documentation
- classmethod approximate(other: Representation, num_bases: int = 3, **kwargs) → Self[source]
Construct a low-rank approximation of another representation.
Note
While this method tries to find a good approximation of the base representation, you may lose any (useful) inductive biases you had with the original one, e.g., from shared tokens in NodePieceRepresentation.
- Parameters:
other (Representation) – The representation to approximate.
num_bases (int) – The number of bases. More bases increase expressivity, but also increase the number of trainable parameters.
kwargs – Additional keyword-based parameters passed to __init__(). Must not contain max_id nor shape, which are determined by other.
- Returns:
A low-rank approximation obtained via (truncated) SVD, cf. torch.svd_lowrank().
- Return type:
Self