InductiveNodePiece
- class InductiveNodePiece(*, triples_factory: ~pykeen.triples.triples_factory.CoreTriplesFactory, inference_factory: ~pykeen.triples.triples_factory.CoreTriplesFactory, num_tokens: int = 2, embedding_dim: int = 64, relation_representations_kwargs: ~collections.abc.Mapping[str, ~typing.Any] | None = None, interaction: str | ~pykeen.nn.modules.Interaction | type[~pykeen.nn.modules.Interaction] | None = <class 'pykeen.nn.modules.DistMultInteraction'>, aggregation: str | ~typing.Callable[[~torch.Tensor, int], ~torch.Tensor] | None = None, validation_factory: ~pykeen.triples.triples_factory.CoreTriplesFactory | None = None, test_factory: ~pykeen.triples.triples_factory.CoreTriplesFactory | None = None, **kwargs)
Bases: InductiveERModel
A wrapper which combines an interaction function with NodePiece entity representations from [galkin2021].
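The default `interaction` is `DistMultInteraction`. As a rough illustration of what an interaction function does, the sketch below scores a single triple with the DistMult formula, the sum over dimensions of h_i * r_i * t_i, using plain Python lists in place of tensors. The vectors are made-up values and `distmult_score` is a hypothetical helper; in this model, h and t would be entity representations produced by NodePiece aggregation. This is not pykeen's implementation.

```python
from typing import Sequence


def distmult_score(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Score a triple with DistMult: the trilinear product sum_i h_i * r_i * t_i."""
    if not len(h) == len(r) == len(t):
        raise ValueError("h, r, and t must have the same dimension")
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


# Made-up 2-dimensional vectors, purely for illustration.
head, relation, tail = [1.0, 2.0], [0.5, 1.0], [2.0, -1.0]
score = distmult_score(head, relation, tail)  # 1.0 * 0.5 * 2.0 + 2.0 * 1.0 * -1.0 = -1.0
# DistMult is symmetric in head and tail, so swapping them gives the same score.
assert score == distmult_score(tail, relation, head)
```

The symmetry check reflects a known property of DistMult: it cannot tell (h, r, t) apart from (t, r, h).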
This model uses the pykeen.nn.NodePieceRepresentation instead of a typical pykeen.nn.Embedding to store representations more efficiently.
Initialize the model.
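To make the storage argument concrete, here is a simplified, stdlib-only sketch of relation-based tokenization: each entity is described by up to `num_tokens` ids of relations incident to it, so vectors are only needed for a small vocabulary of relation tokens rather than one per entity. The helper names, the inverse-id convention (`r + num_relations`), the deterministic truncation, and the padding id `-1` are all assumptions of this sketch; pykeen's own tokenizer may select and pad tokens differently.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

PADDING = -1  # hypothetical padding token id


def relation_tokens(
    triples: Sequence[Tuple[int, int, int]], num_relations: int, num_tokens: int
) -> Dict[int, List[int]]:
    """Map each entity to up to `num_tokens` ids of incident relations."""
    incident = defaultdict(set)
    for h, r, t in triples:
        incident[h].add(r)  # outgoing edge: relation id r
        incident[t].add(r + num_relations)  # incoming edge: assumed inverse id
    tokens = {}
    for entity, rels in incident.items():
        chosen = sorted(rels)[:num_tokens]  # deterministic truncation (simplification)
        tokens[entity] = chosen + [PADDING] * (num_tokens - len(chosen))
    return tokens


def parameter_counts(num_entities: int, num_relations: int, dim: int) -> Tuple[int, int]:
    """Compare an entity embedding table with a relation-token vocabulary."""
    embedding = num_entities * dim
    nodepiece = (2 * num_relations + 1) * dim  # relations, inverses, one padding token
    return embedding, nodepiece


toy_triples = [(0, 0, 1), (0, 1, 2), (1, 0, 2), (2, 2, 3)]
print(relation_tokens(toy_triples, num_relations=3, num_tokens=2))
# {0: [0, 1], 1: [0, 3], 2: [2, 3], 3: [5, -1]}
print(parameter_counts(num_entities=10_000, num_relations=10, dim=64))
# (640000, 1344)
```

Because the tokens depend only on relation ids, an entity that first appears in a new graph can be tokenized with the same vocabulary, which is what makes the inductive setting possible.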
- Parameters:
triples_factory (CoreTriplesFactory) – the triples factory of training triples. Must have create_inverse_triples set to True.
inference_factory (CoreTriplesFactory) – the triples factory of inference triples. Must have create_inverse_triples set to True.
validation_factory (CoreTriplesFactory | None) – the triples factory of validation triples. Must have create_inverse_triples set to True.
test_factory (CoreTriplesFactory | None) – the triples factory of testing triples. Must have create_inverse_triples set to True.
num_tokens (int) – the number of relations to use to represent each entity, cf. pykeen.nn.NodePieceRepresentation.
embedding_dim (int) – the embedding dimension. Only used if embedding_specification is not given.
relation_representations_kwargs (Mapping[str, Any] | None) – the relation representation parameters.
interaction (str | Interaction | type[Interaction] | None) – the interaction module, or a hint for it.
aggregation (str | Callable[[Tensor, int], Tensor] | None) –
aggregation of multiple token representations to a single entity representation. By default, this uses torch.mean(). If a string is provided, the module assumes that it refers to a top-level torch function, e.g., “mean” for torch.mean(), or “sum” for torch.sum(). An aggregation can also have trainable parameters, e.g., MLP(mean(MLP(tokens))) (cf. DeepSets from [zaheer2017]). In this case, the module has to be created outside of this component.
Moreover, we support providing “mlp” as a shortcut to use the MLP aggregation version from [galkin2021].
We could also have aggregations which result in differently shaped outputs, e.g., a concatenation of all token embeddings resulting in shape (num_tokens * d,). In this case, shape must be provided.
The aggregation takes two arguments: the (batched) tensor of token representations, in shape (*, num_tokens, *dt), and the index along which to aggregate.
kwargs – additional keyword-based arguments passed to ERModel.__init__().
- Raises:
ValueError – if the triples factory does not create inverse triples
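The aggregation contract described above (a callable that receives the token representations and the axis to aggregate over) can be illustrated without torch. In this sketch, nested Python lists stand in for a tensor of shape `(batch, num_tokens, d)`, and the three hypothetical functions mimic `torch.mean(x, dim)`, `torch.sum(x, dim)`, and a concatenation whose output has shape `(batch, num_tokens * d)`. To keep it short, the sketch only supports aggregating over the token axis (`dim` of 1 or -2); it is not how pykeen implements these aggregations.

```python
from statistics import fmean
from typing import List

Tensor3 = List[List[List[float]]]  # (batch, num_tokens, d) as nested lists
Tensor2 = List[List[float]]  # (batch, d_out)


def _check_token_axis(dim: int) -> None:
    if dim not in (1, -2):
        raise ValueError("this sketch only aggregates over the token axis")


def mean_aggregation(x: Tensor3, dim: int) -> Tensor2:
    """Average the token vectors of each entity, like torch.mean(x, dim)."""
    _check_token_axis(dim)
    return [[fmean(column) for column in zip(*entity)] for entity in x]


def sum_aggregation(x: Tensor3, dim: int) -> Tensor2:
    """Add up the token vectors of each entity, like torch.sum(x, dim)."""
    _check_token_axis(dim)
    return [[sum(column) for column in zip(*entity)] for entity in x]


def concat_aggregation(x: Tensor3, dim: int) -> Tensor2:
    """Concatenate the token vectors: the output dimension is num_tokens * d."""
    _check_token_axis(dim)
    return [[value for token in entity for value in token] for entity in x]


# Two entities, each with num_tokens=3 token vectors of dimension d=2.
tokens = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]],
]
print(mean_aggregation(tokens, 1))  # [[3.0, 4.0], [2.0, 2.0]]
print(sum_aggregation(tokens, 1))  # [[9.0, 12.0], [6.0, 6.0]]
print(concat_aggregation(tokens, 1))  # [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0.0, 0.0, 2.0, 2.0, 4.0, 4.0]]
```

Mean and sum keep the per-entity dimension at d, while concatenation grows it to num_tokens * d, which is why a differently shaped aggregation needs its output shape stated explicitly.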
Attributes Summary
hpo_default – The default strategy for optimizing the model's hyper-parameters
Methods Summary
create_entity_representation_for_new_triples(triples_factory) – Create NodePiece representations for a new triples factory.
Attributes Documentation
- hpo_default: ClassVar[Mapping[str, Any]] = {'embedding_dim': {'high': 256, 'low': 16, 'q': 16, 'type': <class 'int'>}}
The default strategy for optimizing the model’s hyper-parameters
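The `hpo_default` entry describes an integer search space for `embedding_dim` from 16 to 256 with a quantization step `q` of 16. Assuming `high` is inclusive and `q` acts as a step size (as with Optuna's `suggest_int(..., step=...)`, whose range includes both ends), the space can be enumerated as below. `candidate_values` is a hypothetical helper, not part of pykeen.

```python
from typing import Any, Dict, List

# Copied from the documented hpo_default above.
HPO_DEFAULT: Dict[str, Dict[str, Any]] = {
    "embedding_dim": {"high": 256, "low": 16, "q": 16, "type": int},
}


def candidate_values(spec: Dict[str, Any]) -> List[int]:
    """Enumerate an integer range from low to high (inclusive) in steps of q."""
    if spec["type"] is not int:
        raise TypeError("this sketch only handles integer ranges")
    return list(range(spec["low"], spec["high"] + 1, spec["q"]))


dims = candidate_values(HPO_DEFAULT["embedding_dim"])
print(len(dims), dims[0], dims[-1])  # 16 16 256
```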
Methods Documentation
- create_entity_representation_for_new_triples(triples_factory: CoreTriplesFactory) → NodePieceRepresentation
Create NodePiece representations for a new triples factory.
The representations are initialized such that the same relation representations are used, and the aggregation is shared.
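Because the relation representations and the aggregation are shared, entities that never appeared during training can be represented without creating new parameters. The sketch below illustrates the idea with made-up token vectors and a mean aggregation, and mirrors the documented ValueError when the number of relations differs. `represent_new_entities`, `SHARED_TOKEN_VECTORS`, the inverse-id convention `r + num_relations`, and averaging the padding vector together with real tokens are all assumptions of this sketch; pykeen may handle tokenization and padding differently.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, List, Sequence, Tuple

PADDING = -1
TRAINED_NUM_RELATIONS = 2
# Hypothetical token vectors learned on the training graph (d = 2):
# ids 0-1 are relations, 2-3 their inverses, -1 is padding.
SHARED_TOKEN_VECTORS: Dict[int, List[float]] = {
    0: [1.0, 0.0],
    1: [0.0, 1.0],
    2: [-1.0, 0.0],
    3: [0.0, -1.0],
    PADDING: [0.0, 0.0],
}


def represent_new_entities(
    triples: Sequence[Tuple[int, int, int]], num_relations: int, num_tokens: int = 2
) -> Dict[int, List[float]]:
    """Embed entities of an unseen graph using only the shared token vectors."""
    if num_relations != TRAINED_NUM_RELATIONS:
        raise ValueError("the number of relations differs from the training graph")
    incident = defaultdict(set)
    for h, r, t in triples:
        incident[h].add(r)
        incident[t].add(r + num_relations)  # assumed inverse-id convention
    representations = {}
    for entity, rels in incident.items():
        tokens = sorted(rels)[:num_tokens]
        tokens += [PADDING] * (num_tokens - len(tokens))
        vectors = [SHARED_TOKEN_VECTORS[token] for token in tokens]
        # Shared aggregation: element-wise mean over the token vectors.
        representations[entity] = [fmean(column) for column in zip(*vectors)]
    return representations


# Entities 10, 11, 12 never appeared in training; no new parameters are created.
print(represent_new_entities([(10, 0, 11), (11, 1, 12)], num_relations=2))
# {10: [0.5, 0.0], 11: [-0.5, 0.5], 12: [0.0, -0.5]}
```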
- Parameters:
triples_factory (CoreTriplesFactory) – the triples factory used for relation tokenization; must share the same relation to ID mapping.
- Returns:
a new NodePiece entity representation with shared relation tokenization and aggregation.
- Raises:
ValueError – if the triples factory does not request inverse triples, or the number of relations differs.
- Return type: