ERModel
- class ERModel(*, triples_factory: KGInfo, interaction: str | Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation] | type[Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation]] | None, interaction_kwargs: Mapping[str, Any] | None = None, entity_representations: str | Representation | type[Representation] | None | Sequence[str | Representation | type[Representation] | None] = None, entity_representations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, relation_representations: str | Representation | type[Representation] | None | Sequence[str | Representation | type[Representation] | None] = None, relation_representations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, skip_checks: bool = False, **kwargs)
Bases: Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], _NewAbstractModel
A commonly useful base for KGEMs using embeddings and interaction modules.
This model does not use post-init hooks to automatically initialize all of its parameters. Rather, the call to Model.reset_parameters_() happens at the end of ERModel.__init__. This is possible because all trainable parameters must be passed through super().__init__() in subclasses of ERModel.
Other code can still be put after the call to super().__init__() in subclasses, such as registering regularizers (as done in pykeen.models.ConvKB and pykeen.models.TransH).
Initialize the module.
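The initialization order described above can be illustrated with a minimal, hypothetical sketch in plain Python. The classes ToyBaseModel and ToyModelWithRegularizer and the regularizer label "lp" are invented for illustration and are not part of PyKEEN; the sketch only mirrors the pattern of resetting all parameters as the last step of the base constructor and registering extra components afterwards in the subclass.

```python
class ToyBaseModel:
    """Hypothetical stand-in for a base model (not PyKEEN's ERModel)."""

    def __init__(self, parameters):
        # All trainable parameters are handed to the base constructor...
        self.parameters = dict(parameters)
        self.regularizers = []
        self.reset_log = []
        # ...so they can all be (re-)initialized as the last step of __init__.
        self.reset_parameters_()

    def reset_parameters_(self):
        for name in self.parameters:
            self.parameters[name] = 0.0  # toy "initialization"
            self.reset_log.append(name)


class ToyModelWithRegularizer(ToyBaseModel):
    def __init__(self):
        super().__init__(parameters={"entity": 1.0, "relation": 2.0})
        # Code after super().__init__() still runs; here a (hypothetical) regularizer
        # is registered for a parameter that already exists and was already reset.
        self.regularizers.append(("relation", "lp"))


model = ToyModelWithRegularizer()
print(model.reset_log)     # ['entity', 'relation']
print(model.regularizers)  # [('relation', 'lp')]
```

Because every parameter is known by the time the base constructor finishes, no post-init hook is needed; the subclass only adds components that refer to parameters which already exist.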
- Parameters:
triples_factory (KGInfo) – The triples factory facilitates access to the dataset.
interaction (Interaction) – The interaction module (e.g., TransE), given as an instance, a class, or a string hint thereof.
interaction_kwargs (OptionalKwargs) – Additional keyword-based parameters passed to the interaction module’s constructor, if it is not already instantiated.
entity_representations (Sequence[Representation]) – The entity representation, or a sequence of representations.
entity_representations_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters for instantiating the entity representations.
relation_representations (Sequence[Representation]) – The relation representation, or a sequence of representations.
relation_representations_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters for instantiating the relation representations.
skip_checks (bool) – Whether to skip the entity representation checks.
kwargs – Keyword arguments to pass to the base model.
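Each representation argument may be a single value or a sequence, with either a single mapping or a sequence of mappings as its kwargs. The pairing can be pictured with the hypothetical helper normalize_one_or_many below. The helper name, its rule of reusing a single kwargs mapping for every hint, and the "dim" key are assumptions for illustration only; PyKEEN's own resolution logic may behave differently.

```python
from collections.abc import Mapping, Sequence


def normalize_one_or_many(hints, kwargs=None):
    """Hypothetical helper: pair each representation hint with its kwargs.

    A single hint (including a plain string or None) becomes a one-element list,
    and a single kwargs mapping (or None) is reused for every hint.
    """
    if isinstance(hints, str) or not isinstance(hints, Sequence):
        hints = [hints]
    else:
        hints = list(hints)
    if kwargs is None or isinstance(kwargs, Mapping):
        kwargs = [kwargs] * len(hints)
    else:
        kwargs = list(kwargs)
    if len(kwargs) != len(hints):
        raise ValueError(f"got {len(hints)} hints but {len(kwargs)} kwargs")
    return list(zip(hints, kwargs))


# One representation with one set of kwargs.
print(normalize_one_or_many("embedding", {"dim": 50}))
# Two representations, each with its own kwargs.
print(normalize_one_or_many(["embedding", "embedding"], [{"dim": 50}, {"dim": 10}]))
```

Strings are treated as a single hint even though Python strings are sequences, which is why the string check comes first.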
Methods Summary
append_weight_regularizer(parameter, regularizer) – Add a model weight to a regularizer's weight list, and register the regularizer with the model.
forward(h_indices, r_indices, t_indices[, ...]) – Forward pass.
score_h(rt_batch, *[, slice_size, mode, heads]) – Forward pass using left side (head) prediction.
score_hrt(hrt_batch, *[, mode]) – Forward pass.
score_r(ht_batch, *[, slice_size, mode, ...]) – Forward pass using middle (relation) prediction.
score_t(hr_batch, *[, slice_size, mode, tails]) – Forward pass using right side (tail) prediction.
Methods Documentation
- append_weight_regularizer(parameter: str | Parameter | Iterable[str | Parameter], regularizer: str | Regularizer | type[Regularizer] | None, regularizer_kwargs: Mapping[str, Any] | None = None, default_regularizer: str | Regularizer | type[Regularizer] | None = None, default_regularizer_kwargs: Mapping[str, Any] | None = None) -> None
Add a model weight to a regularizer’s weight list, and register the regularizer with the model.
- Parameters:
parameter (str | Parameter | Iterable[str | Parameter]) – The parameter(s), given as names or as nn.Parameter objects. A list of available parameter names is shown by sorted(dict(self.named_parameters()).keys()).
regularizer (str | Regularizer | type[Regularizer] | None) – The regularizer, or a hint thereof.
regularizer_kwargs (Mapping[str, Any] | None) – Additional keyword-based parameters for the regularizer’s instantiation.
default_regularizer (str | Regularizer | type[Regularizer] | None) – The default regularizer; if None, use regularizer_default.
default_regularizer_kwargs (Mapping[str, Any] | None) – The default regularizer kwargs; if None, use regularizer_default_kwargs.
- Raises:
KeyError – If an invalid parameter name was given.
- Return type:
None
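The name lookup and the KeyError described above can be sketched with a hypothetical stand-in class. ToyModel, its parameter names entity_weight and relation_weight, and the way the regularizer is recorded are invented for illustration. Unlike the documented method, this sketch accepts parameter names only (not nn.Parameter objects) and ignores the regularizer kwargs and defaults.

```python
class ToyModel:
    """Hypothetical stand-in exposing named parameters (not a PyKEEN class)."""

    def __init__(self):
        self._parameters = {"entity_weight": [0.1, 0.2], "relation_weight": [0.3]}
        self.weight_regularizers = []

    def named_parameters(self):
        return iter(self._parameters.items())

    def append_weight_regularizer(self, parameter, regularizer):
        # Accept either a single name or an iterable of names.
        names = [parameter] if isinstance(parameter, str) else list(parameter)
        available = dict(self.named_parameters())
        # Validate every name before registering anything.
        for name in names:
            if name not in available:
                raise KeyError(f"Invalid parameter name {name!r}; available: {sorted(available.keys())}")
        self.weight_regularizers.append((regularizer, names))


model = ToyModel()
print(sorted(dict(model.named_parameters()).keys()))  # ['entity_weight', 'relation_weight']
model.append_weight_regularizer("relation_weight", regularizer="lp")
try:
    model.append_weight_regularizer("relation_wieght", regularizer="lp")  # deliberate typo
except KeyError as error:
    print("KeyError:", error)
```

Listing the available names in the error message mirrors the advice above to inspect sorted(dict(self.named_parameters()).keys()) when a name is rejected.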
- forward(h_indices: Tensor, r_indices: Tensor, t_indices: Tensor, slice_size: int | None = None, slice_dim: int = 0, *, mode: Literal['training', 'validation', 'testing'] | None) -> Tensor
Forward pass.
This method takes head, relation and tail indices and calculates the corresponding scores. It supports broadcasting.
- Parameters:
h_indices (Tensor) – The head indices.
r_indices (Tensor) – The relation indices.
t_indices (Tensor) – The tail indices.
slice_size (int | None) – The slice size; if None, no slicing is performed.
slice_dim (int) – The dimension along which to slice.
mode (Literal['training', 'validation', 'testing'] | None) – The pass mode, which is None in the transductive setting and one of “training”, “validation”, or “testing” in the inductive setting.
- Returns:
The scores.
- Raises:
NotImplementedError – If score repetition becomes necessary.
- Return type:
Tensor
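To illustrate the broadcasting behaviour, the following plain-Python sketch broadcasts 2-D index grids (nested lists) and scores them with a DistMult-style product. The embedding tables ENTITY and RELATION, the helper broadcast_2d, and the choice of DistMult are hypothetical. The real method operates on tensors, delegates to the configured interaction module, and also supports slicing and the mode argument, none of which is modelled here.

```python
# Hypothetical toy embeddings: 3 entities and 2 relations, dimension 2.
ENTITY = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
RELATION = [[2.0, 1.0], [1.0, 3.0]]


def distmult(h, r, t):
    """DistMult-style score: sum of element-wise products."""
    return sum(a * b * c for a, b, c in zip(h, r, t))


def broadcast_2d(*grids):
    """Broadcast 2-D nested lists to a common (rows, cols) shape, NumPy-style."""
    rows = max(len(g) for g in grids)
    cols = max(len(g[0]) for g in grids)
    for g in grids:
        if len(g) not in (1, rows) or len(g[0]) not in (1, cols):
            raise ValueError("index shapes are not broadcastable")
    # A dimension of size 1 is repeated along that axis.
    return [
        [[g[i if len(g) > 1 else 0][j if len(g[0]) > 1 else 0] for j in range(cols)] for i in range(rows)]
        for g in grids
    ]


def forward(h_indices, r_indices, t_indices):
    h_b, r_b, t_b = broadcast_2d(h_indices, r_indices, t_indices)
    return [
        [distmult(ENTITY[h], RELATION[r], ENTITY[t]) for h, r, t in zip(h_row, r_row, t_row)]
        for h_row, r_row, t_row in zip(h_b, r_b, t_b)
    ]


# Heads and relations of shape (2, 1), all tails of shape (1, 3) -> scores of shape (2, 3).
scores = forward([[0], [1]], [[0], [1]], [[0, 1, 2]])
print(scores)  # [[2.0, 0.0, 2.0], [0.0, 3.0, 3.0]]
```

Broadcasting lets a single call express both per-triple scoring (all shapes equal) and scoring against many candidates (a size-1 dimension on one side).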
- score_h(rt_batch: Tensor, *, slice_size: int | None = None, mode: Literal['training', 'validation', 'testing'] | None = None, heads: Tensor | None = None) -> Tensor
Forward pass using left side (head) prediction.
This method calculates the score for all possible heads for each (relation, tail) pair.
- Parameters:
rt_batch (Tensor) – shape: (batch_size, 2), dtype: long. The indices of (relation, tail) pairs.
slice_size (int | None) – The divisor for the scoring function when using slicing; must be > 0.
mode (Literal['training', 'validation', 'testing'] | None) – The pass mode, which is None in the transductive setting and one of “training”, “validation”, or “testing” in the inductive setting.
heads (Tensor | None) – shape: (num_heads,) | (batch_size, num_heads). Head entity indices to score against. If None, scores against all entities (from the given mode).
- Returns:
shape: (batch_size, num_heads), dtype: float. For each r-t pair, the scores for all possible heads.
- Return type:
Tensor
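A plain-Python sketch of the head-prediction semantics, including both accepted shapes of heads, follows. The toy embeddings, the DistMult-style score, and the use of nested lists in place of tensors are assumptions for illustration; slicing and mode are not modelled.

```python
# Hypothetical toy embeddings: 3 entities and 2 relations, dimension 2.
ENTITY = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
RELATION = [[2.0, 1.0], [1.0, 3.0]]


def distmult(h, r, t):
    return sum(a * b * c for a, b, c in zip(h, r, t))


def score_h(rt_batch, heads=None):
    """Score candidate heads for each (relation, tail) pair in rt_batch."""
    if heads is None:
        heads = list(range(len(ENTITY)))  # score against all entities
    per_row = bool(heads) and isinstance(heads[0], list)  # 2-D: one candidate list per row
    scores = []
    for row, (r, t) in enumerate(rt_batch):
        candidates = heads[row] if per_row else heads
        scores.append([distmult(ENTITY[h], RELATION[r], ENTITY[t]) for h in candidates])
    return scores


rt_batch = [(0, 2), (1, 0)]
print(score_h(rt_batch))                          # [[2.0, 1.0, 3.0], [1.0, 0.0, 1.0]]
print(score_h(rt_batch, heads=[2]))               # [[3.0], [1.0]]
print(score_h(rt_batch, heads=[[0, 1], [2, 0]]))  # [[2.0, 1.0], [1.0, 1.0]]
```

A 1-D heads argument shares one candidate set across the batch, while a 2-D argument gives each (relation, tail) pair its own candidates.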
- score_hrt(hrt_batch: Tensor, *, mode: Literal['training', 'validation', 'testing'] | None = None) -> Tensor
Forward pass.
This method takes head, relation and tail of each triple and calculates the corresponding score.
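- Parameters:
hrt_batch (Tensor) – shape: (batch_size, 3), dtype: long. The indices of (head, relation, tail) triples.
mode (Literal['training', 'validation', 'testing'] | None) – The pass mode, which is None in the transductive setting and one of “training”, “validation”, or “testing” in the inductive setting.
- Returns:
shape: (batch_size, 1), dtype: float. The score for each triple.
- Return type:
Tensor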
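The shape contract, with one score per input triple kept in a trailing dimension of size 1, can be sketched in plain Python as follows. The toy embeddings and the DistMult-style score are hypothetical stand-ins for the model's representations and interaction module.

```python
ENTITY = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical: 3 entities
RELATION = [[2.0, 1.0], [1.0, 3.0]]            # hypothetical: 2 relations


def distmult(h, r, t):
    return sum(a * b * c for a, b, c in zip(h, r, t))


def score_hrt(hrt_batch):
    """Return one score per (head, relation, tail) row, shaped (batch_size, 1)."""
    return [[distmult(ENTITY[h], RELATION[r], ENTITY[t])] for h, r, t in hrt_batch]


scores = score_hrt([(0, 0, 2), (1, 1, 1)])
print(scores)  # [[2.0], [3.0]]
```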
- score_r(ht_batch: Tensor, *, slice_size: int | None = None, mode: Literal['training', 'validation', 'testing'] | None = None, relations: Tensor | None = None) -> Tensor
Forward pass using middle (relation) prediction.
This method calculates the score for all possible relations for each (head, tail) pair.
- Parameters:
ht_batch (Tensor) – shape: (batch_size, 2), dtype: long. The indices of (head, tail) pairs.
slice_size (int | None) – The divisor for the scoring function when using slicing; must be > 0.
mode (Literal['training', 'validation', 'testing'] | None) – The pass mode, which is None in the transductive setting and one of “training”, “validation”, or “testing” in the inductive setting.
relations (Tensor | None) – shape: (num_relations,) | (batch_size, num_relations). Relation indices to score against. If None, scores against all relations (from the given mode).
- Returns:
shape: (batch_size, num_real_relations), dtype: float. For each h-t pair, the scores for all possible relations.
- Return type:
Tensor
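The following plain-Python sketch shows relation prediction and checks that each entry equals the score of the corresponding full triple. As in the other sketches, the embeddings and the DistMult-style score are hypothetical, slicing and mode are omitted, and the optional relations argument is handled only in its 1-D form.

```python
ENTITY = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical: 3 entities
RELATION = [[2.0, 1.0], [1.0, 3.0]]            # hypothetical: 2 relations


def distmult(h, r, t):
    return sum(a * b * c for a, b, c in zip(h, r, t))


def score_r(ht_batch, relations=None):
    """Score candidate relations for each (head, tail) pair in ht_batch."""
    if relations is None:
        relations = range(len(RELATION))  # score against all relations
    return [[distmult(ENTITY[h], RELATION[r], ENTITY[t]) for r in relations] for h, t in ht_batch]


ht_batch = [(2, 2), (0, 2)]
scores = score_r(ht_batch)
print(scores)  # [[3.0, 4.0], [2.0, 1.0]]

# Each entry equals the score of the full triple (h, r, t).
for (h, t), row in zip(ht_batch, scores):
    for r, value in enumerate(row):
        assert value == distmult(ENTITY[h], RELATION[r], ENTITY[t])
```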
- score_t(hr_batch: Tensor, *, slice_size: int | None = None, mode: Literal['training', 'validation', 'testing'] | None = None, tails: Tensor | None = None) -> Tensor
Forward pass using right side (tail) prediction.
This method calculates the score for all possible tails for each (head, relation) pair.
- Parameters:
hr_batch (Tensor) – shape: (batch_size, 2), dtype: long. The indices of (head, relation) pairs.
slice_size (int | None) – The divisor for the scoring function when using slicing; must be > 0.
mode (Literal['training', 'validation', 'testing'] | None) – The pass mode, which is None in the transductive setting and one of “training”, “validation”, or “testing” in the inductive setting.
tails (Tensor | None) – shape: (num_tails,) | (batch_size, num_tails). Tail entity indices to score against. If None, scores against all entities (from the given mode).
- Returns:
shape: (batch_size, num_tails), dtype: float. For each h-r pair, the scores for all possible tails.
- Return type:
Tensor
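A plain-Python sketch of tail prediction follows, with slicing modelled as processing the candidate tails in chunks of at most slice_size. Splitting along the candidate dimension is an assumption made for illustration, as are the toy embeddings and the DistMult-style score. The sketch shows that slicing bounds the amount of work per step without changing the resulting scores.

```python
ENTITY = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical: 3 entities
RELATION = [[2.0, 1.0], [1.0, 3.0]]            # hypothetical: 2 relations


def distmult(h, r, t):
    return sum(a * b * c for a, b, c in zip(h, r, t))


def score_t(hr_batch, slice_size=None, tails=None):
    """Score candidate tails for each (head, relation) pair, optionally in slices."""
    if tails is None:
        tails = list(range(len(ENTITY)))  # score against all entities
    step = max(len(tails), 1) if slice_size is None else slice_size
    if step <= 0:
        raise ValueError("slice_size must be > 0")
    scores = [[] for _ in hr_batch]
    # Score the candidate tails chunk by chunk, at most `step` tails at a time.
    for start in range(0, len(tails), step):
        chunk = tails[start:start + step]
        for row, (h, r) in enumerate(hr_batch):
            scores[row].extend(distmult(ENTITY[h], RELATION[r], ENTITY[t]) for t in chunk)
    return scores


hr_batch = [(0, 0), (2, 1)]
print(score_t(hr_batch))                # [[2.0, 0.0, 2.0], [1.0, 3.0, 4.0]]
print(score_t(hr_batch, slice_size=2))  # same scores, computed in two chunks
```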