TextRepresentation

class TextRepresentation(labels: Sequence[str | None], max_id: int | None = None, shape: int | Sequence[int] | None = None, encoder: str | TextEncoder | type[TextEncoder] | None = None, encoder_kwargs: Mapping[str, Any] | None = None, missing_action: Literal['blank', 'error'] = 'error', **kwargs: Any)[source]

Bases: Representation

Textual representations using a text encoder on labels.

Example Usage:

Entity representations are obtained by encoding the entity labels with a transformer model. The transformer becomes part of the KGE model, and its parameters are trained jointly.

"""Using text representations."""

import torch

from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn import TextRepresentation
from pykeen.pipeline import pipeline

dataset = get_dataset(dataset="nations")
entity_representations = TextRepresentation.from_dataset(dataset=dataset, encoder="transformer")
result = pipeline(
    dataset=dataset,
    model=ERModel,
    model_kwargs=dict(
        interaction="ermlpe",
        interaction_kwargs=dict(
            embedding_dim=entity_representations.shape[0],
        ),
        entity_representations=entity_representations,
        relation_representations_kwargs=dict(
            shape=entity_representations.shape,
        ),
    ),
    training_kwargs=dict(num_epochs=1),
)
model = result.model

# representations for unseen entities
entity_representation = model.entity_representations[0]
label_encoder = entity_representation.encoder
uk, united_kingdom = label_encoder(labels=["uk", "united kingdom"])

# true triple from train: ['brazil', 'exports3', 'uk']
relation_representation = model.relation_representations[0]
h_repr = entity_representation.get_in_more_canonical_shape(
    dim="h", indices=torch.as_tensor(dataset.entity_to_id["brazil"]).view(1)
)
r_repr = relation_representation.get_in_more_canonical_shape(
    dim="r", indices=torch.as_tensor(dataset.relation_to_id["exports3"]).view(1)
)
scores = model.interaction(h=h_repr, r=r_repr, t=torch.stack([uk, united_kingdom]))
print(scores)

Initialize the representation.

Parameters:
  • labels (Sequence[str | None]) – An ordered, finite collection of labels.

  • max_id (int | None) – The number of representations. If provided, it must match the number of labels.

  • shape (OneOrSequence[int] | None) – The shape of an individual representation.

  • encoder (HintOrType[TextEncoder]) – The text encoder, or a hint thereof.

  • encoder_kwargs (OptionalKwargs) – Keyword-based parameters used to instantiate the text encoder.

  • missing_action (Literal['blank', 'error']) – The policy for handling None values in the given labels. If “error”, an error is raised for any None label. If “blank”, None labels are replaced with an empty string (see the sketch after the note below).

  • kwargs (Any) – Additional keyword-based parameters passed to pykeen.nn.representation.Representation.

Raises:

ValueError – If max_id is provided and does not match the number of labels.

Note

The parameter pair (encoder, encoder_kwargs) is used for pykeen.nn.text.text_encoder_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
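
A minimal sketch of constructing the representation directly from a list of labels; the label strings and the “transformer” encoder hint below are illustrative assumptions, not part of the example above.

from pykeen.nn import TextRepresentation

# one None entry to illustrate missing_action="blank"
labels = ["brazil", "uk", None]
representation = TextRepresentation(
    labels=labels,
    encoder="transformer",  # resolved via pykeen.nn.text.text_encoder_resolver
    missing_action="blank",  # the None label is replaced with an empty string
)
print(representation.max_id, representation.shape)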

Methods Summary

from_dataset(dataset[, for_entities])

Prepare a text representation with labels from a dataset.

from_triples_factory(triples_factory[, ...])

Prepare a text representation with labels from a triples factory.

Methods Documentation

classmethod from_dataset(dataset: Dataset, for_entities: bool = True, **kwargs) TextRepresentation[source]

Prepare a text representation with labels from a dataset.

Parameters:
  • dataset (Dataset) – The dataset providing the entity (or relation) labels.

  • for_entities (bool) – Whether to create the representation for entities (True) or for relations (False).

  • kwargs – Additional keyword-based parameters passed to from_triples_factory().

Returns:

A text representation from the dataset.

Raises:

TypeError – If the dataset’s triples factory does not provide labels.

Return type:

TextRepresentation
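
A hedged sketch of the for_entities flag: with for_entities=False, the representation is built from the dataset’s relation labels instead of its entity labels. The dataset and encoder choices below mirror the illustrative ones used above.

from pykeen.datasets import get_dataset
from pykeen.nn import TextRepresentation

dataset = get_dataset(dataset="nations")
# encode relation labels rather than entity labels
relation_representations = TextRepresentation.from_dataset(
    dataset=dataset,
    for_entities=False,
    encoder="transformer",
)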

classmethod from_triples_factory(triples_factory: TriplesFactory, for_entities: bool = True, **kwargs) TextRepresentation[source]

Prepare a text representation with labels from a triples factory.

Parameters:
  • triples_factory (TriplesFactory) – The triples factory providing the labels.

  • for_entities (bool) – Whether to create the representation for entities (True) or for relations (False).

  • kwargs – Additional keyword-based parameters passed to TextRepresentation.

Returns:

A text representation from the triples factory.

Return type:

TextRepresentation
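
A hedged sketch of from_triples_factory, assuming a labeled triples factory such as the training factory of the Nations dataset (an illustrative choice); keyword-based parameters like the encoder hint are forwarded to the constructor.

from pykeen.datasets import get_dataset
from pykeen.nn import TextRepresentation

dataset = get_dataset(dataset="nations")
# dataset.training is a labeled TriplesFactory
entity_representations = TextRepresentation.from_triples_factory(
    triples_factory=dataset.training,
    encoder="transformer",
)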