LowRankRepresentation

class LowRankRepresentation(*, max_id: int | None = None, shape: Sequence[int] | int | None = None, num_bases: int | None = 3, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, weight: str | Representation | type[Representation] | None = None, weight_kwargs: Mapping[str, Any] | None = None, **kwargs)[source]

Bases: Representation

Low-rank embedding factorization.

This representation reduces the number of trainable parameters by not learning independent weights for each index; instead, a set of bases is shared across all indices, and only the per-index weights of their linear combination are learned.

\[E[i] = \sum_k B[i, k] \cdot W[k]\]

This representation implements the generalized form, where both \(B\) and \(W\) are arbitrary representations themselves.
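As a minimal NumPy sketch of the factorization (the sizes below are illustrative, not pykeen defaults), a small per-index weight matrix \(B\) and a shared basis matrix \(W\) replace a full embedding table:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

max_id, dim, num_bases = 100, 32, 3

B = rng.normal(size=(max_id, num_bases))  # per-index combination weights
W = rng.normal(size=(num_bases, dim))     # bases shared by all indices

# E[i] = sum_k B[i, k] * W[k], i.e. a plain matrix product
E = B @ W

# the factorization stores 100*3 + 3*32 = 396 parameters
# instead of the full 100*32 = 3200 table
print(E.shape, B.size + W.size, max_id * dim)
```

The savings grow with `max_id`: the basis cost is fixed, while the per-index cost scales with `num_bases` rather than with the embedding dimension.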

Example usage:

"""Use the (generalized) low-rank approximation to create a mixture model representation."""

import pandas

from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn import LowRankRepresentation
from pykeen.nn.text.cache import WikidataTextCache
from pykeen.pipeline import pipeline
from pykeen.typing import FloatTensor

dataset = get_dataset(dataset="CoDExSmall", dataset_kwargs=dict(create_inverse_triples=True))

# set up relation representations as a mixture (~soft clustering) with 5 components
embedding_dim = 32
num_components = 5
relation_representation = LowRankRepresentation(
    max_id=dataset.num_relations,
    shape=embedding_dim,
    num_bases=num_components,
    weight_kwargs=dict(normalizer="softmax"),
)
# use DistMult interaction, and a simple embedding matrix for relations
model = ERModel[FloatTensor, FloatTensor, FloatTensor](
    triples_factory=dataset.training,
    interaction="distmult",
    entity_representations_kwargs=dict(shape=embedding_dim),
    relation_representations=relation_representation,
)
result = pipeline(dataset=dataset, model=model, training_kwargs=dict(num_epochs=20))

# keys are Wikidata IDs, which are the "labels" in CoDEx, and values
# are the concatenation of the Wikidata label + description
wikidata_id_to_label = WikidataTextCache().get_texts_dict(dataset.relation_to_id)

# use the mixture weights
relation_weights = relation_representation.weight().detach().cpu().numpy()
rows = [
    (relation_index, wikidata_id, wikidata_id_to_label[wikidata_id], component, weight)
    for wikidata_id, relation_index in dataset.relation_to_id.items()
    for component, weight in enumerate(relation_weights[relation_index])
]
df = pandas.DataFrame(data=rows, columns=["relation_index", "wikidata-id", "text", "component_index", "weight"])


# For each component, look at the relations that are most assigned to it
print(
    df.groupby(by="component_index").apply(lambda g: g.nlargest(3, columns="weight"), include_groups=False)[
        ["wikidata-id", "text", "weight"]
    ]
)

Initialize the representations.

Parameters:
  • max_id (int) – The maximum ID (exclusive); valid IDs range from 0 to max_id - 1. If None, a pre-instantiated weight representation needs to be provided.

  • shape (tuple[int, ...]) – The shape of an individual representation. If None, a pre-instantiated base representation has to be provided.

  • num_bases (int | None) – The number of bases. More bases increase expressivity, but also increase the number of trainable parameters. If None, a pre-instantiated base representation has to be provided.

  • weight (HintOrType[Representation]) – The weight representation, or a hint thereof.

  • weight_kwargs (OptionalKwargs) – Additional keyword-based arguments used to instantiate the weight representation.

  • base (HintOrType[Representation]) – The base representation, or a hint thereof.

  • base_kwargs (OptionalKwargs) – Additional keyword-based arguments used to instantiate the base representation.

  • kwargs – Additional keyword-based arguments passed to Representation.

Raises:

MaxIDMismatchError – If max_id was given explicitly and does not match the max_id of the weight representation.

Note

The parameter pairs (base, base_kwargs) and (weight, weight_kwargs) are passed to pykeen.nn.representation_resolver to instantiate the respective representations.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
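To illustrate the hint-or-type pattern in isolation (a simplified toy, not the actual class-resolver implementation), a hint may be a string, a class, or an already-instantiated object:

```python
class Representation:
    """Toy stand-in for pykeen's Representation base class."""

    def __init__(self, max_id: int, shape: int):
        self.max_id, self.shape = max_id, shape


class Embedding(Representation):
    """Toy subclass resolvable by its (lower-cased) name."""


def resolve(hint, kwargs=None, base=Representation):
    """Turn a string / class / instance hint into an instance."""
    if isinstance(hint, base):  # already instantiated: pass through
        return hint
    if isinstance(hint, str):  # string hint: look up subclass by name
        hint = {c.__name__.lower(): c for c in base.__subclasses__()}[hint.lower()]
    return hint(**(kwargs or {}))  # class hint: instantiate with kwargs


rep = resolve("embedding", dict(max_id=10, shape=32))
```

This is why the `_kwargs` companions exist: they are only consumed when the hint is a string or a class, since a pre-instantiated representation is passed through unchanged.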

Attributes Summary

num_bases

Return the number of bases.

Methods Summary

approximate(other[, num_bases])

Construct a low-rank approximation of another representation.

Attributes Documentation

num_bases

Return the number of bases.

Methods Documentation

classmethod approximate(other: Representation, num_bases: int = 3, **kwargs) → Self[source]

Construct a low-rank approximation of another representation.

Note

While this method tries to find a good approximation of the base representation, you may lose any (useful) inductive biases you had with the original one, e.g., from shared tokens in NodePieceRepresentation.

Parameters:
  • other (Representation) – The representation to approximate.

  • num_bases (int) – The number of bases. More bases increase expressivity, but also increase the number of trainable parameters.

  • kwargs – Additional keyword-based parameters passed to __init__(). Must contain neither max_id nor shape, as both are determined by other.

Returns:

A low-rank approximation obtained via (truncated) SVD, cf. torch.svd_lowrank().

Return type:

Self
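The underlying construction can be sketched with a plain truncated SVD (NumPy here instead of torch.svd_lowrank(), purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# a dense representation matrix standing in for `other`
E = rng.normal(size=(50, 16))

num_bases = 3
# truncated SVD: keep only the top `num_bases` singular triples
U, S, Vt = np.linalg.svd(E, full_matrices=False)
weight = U[:, :num_bases] * S[:num_bases]  # shape: (50, 3)
bases = Vt[:num_bases]                     # shape: (3, 16)

# by the Eckart-Young theorem, weight @ bases is the best rank-3
# approximation of E in Frobenius norm
approx = weight @ bases
```

The resulting `weight` and `bases` matrices play the roles of \(B\) and \(W\) in the formula above; the reconstruction error shrinks as `num_bases` approaches the rank of the original matrix.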