LabelBasedInitializer
- class LabelBasedInitializer(labels: Sequence[str], encoder: str | TextEncoder | type[TextEncoder] | None = None, encoder_kwargs: Mapping[str, Any] | None = None, batch_size: int | None = None)
Bases: PretrainedInitializer
An initializer using pretrained models from the transformers library to encode labels.
Example Usage:
Initialize entity representations as Transformer encodings of their labels. Afterwards, the parameters are detached from the labels, and trained on the KGE task without any further connection to the Transformer model.
    from pykeen.datasets import get_dataset
    from pykeen.nn.init import LabelBasedInitializer
    from pykeen.models import ERMLPE

    dataset = get_dataset(dataset="nations")
    entity_initializer = LabelBasedInitializer.from_triples_factory(
        triples_factory=dataset.training,
        encoder="transformer",
    )
    model = ERMLPE(
        triples_factory=dataset.training,
        embedding_dim=entity_initializer.tensor.shape[-1],  # 768 for BERT base
        entity_initializer=entity_initializer,
        # note: we explicitly need to provide a relation initializer here,
        # since ERMLPE shares initializers between entities and relations by default
        relation_initializer="uniform",
    )
Initialize the initializer.
- Parameters:
labels (Sequence[str]) – the labels to encode
encoder (str | TextEncoder | type[TextEncoder] | None) – the text encoder to use, cf. text_encoder_resolver
encoder_kwargs (Mapping[str, Any] | None) – additional keyword-based parameters passed to the encoder
batch_size (int | None) – the (maximum) batch size to use while encoding; must be positive. If None, use len(labels), i.e., encode all labels in a single batch.
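The effect of batch_size can be illustrated with a minimal, library-independent sketch. The helpers iter_batches, encode_all, and toy_encode below are hypothetical names introduced for illustration only; they are not part of PyKEEN, and toy_encode merely stands in for a real text encoder.

```python
from __future__ import annotations

from collections.abc import Callable, Iterator, Sequence


def iter_batches(labels: Sequence[str], batch_size: int | None = None) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of ``labels`` with at most ``batch_size`` items each."""
    if batch_size is None:
        # Mirror the documented default: a single batch holding every label.
        batch_size = max(len(labels), 1)
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(labels), batch_size):
        yield labels[start : start + batch_size]


def toy_encode(batch: Sequence[str]) -> list[list[float]]:
    """Hypothetical stand-in encoder: maps each label to a 1-dimensional vector."""
    return [[float(len(label))] for label in batch]


def encode_all(
    labels: Sequence[str],
    encode: Callable[[Sequence[str]], list[list[float]]],
    batch_size: int | None = None,
) -> list[list[float]]:
    """Encode all labels batch by batch and concatenate the results in order."""
    vectors: list[list[float]] = []
    for batch in iter_batches(labels, batch_size):
        vectors.extend(encode(batch))
    return vectors


labels = ["brazil", "burma", "china", "cuba", "egypt"]  # illustrative labels
batched = encode_all(labels, toy_encode, batch_size=2)
single = encode_all(labels, toy_encode)  # batch_size=None: one batch
```

Under these assumptions, the batch size changes only how many labels are processed at once (and hence peak memory use), not the resulting vectors or their order.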
Methods Summary
from_triples_factory(triples_factory[, ...]) – Prepare a label-based initializer with labels from a triples factory.
Methods Documentation
- classmethod from_triples_factory(triples_factory: TriplesFactory, for_entities: bool = True, **kwargs) → LabelBasedInitializer
Prepare a label-based initializer with labels from a triples factory.
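As a rough sketch of what the for_entities switch implies, the snippet below picks either entity or relation labels and orders them by ID, so that position i of the label list lines up with ID i. ToyFactory and labels_in_id_order are hypothetical stand-ins written for this illustration; they do not reproduce PyKEEN's TriplesFactory or its internal implementation.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyFactory:
    """Hypothetical stand-in for a triples factory (not PyKEEN's TriplesFactory)."""

    entity_to_id: dict[str, int]
    relation_to_id: dict[str, int]


def labels_in_id_order(factory: ToyFactory, for_entities: bool = True) -> list[str]:
    """Return entity labels (or relation labels) sorted by their integer ID."""
    mapping = factory.entity_to_id if for_entities else factory.relation_to_id
    # Sorting by ID keeps the i-th label aligned with the i-th representation row.
    return sorted(mapping, key=mapping.__getitem__)


factory = ToyFactory(
    entity_to_id={"china": 1, "brazil": 0},
    relation_to_id={"treaties": 1, "embassy": 0, "exports": 2},
)
entity_labels = labels_in_id_order(factory)
relation_labels = labels_in_id_order(factory, for_entities=False)
```

The resulting list is what would be handed to the initializer as labels; any remaining keyword arguments would be forwarded unchanged.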
- Parameters:
triples_factory (TriplesFactory) – the triples factory
for_entities (bool) – whether to create the initializer for entities (or relations)
kwargs – additional keyword-based arguments passed to LabelBasedInitializer.__init__()
- Returns:
A label-based initializer
- Raises:
ImportError – if the transformers library could not be imported
- Return type: