ERModel
- class ERModel(*, triples_factory, interaction, interaction_kwargs=None, entity_representations=None, relation_representations=None, loss=None, loss_kwargs=None, predict_with_sigmoid=False, preferred_device=None, random_seed=None, skip_checks=False)
Bases: Generic[pykeen.typing.HeadRepresentation, pykeen.typing.RelationRepresentation, pykeen.typing.TailRepresentation], pykeen.models.nbase._NewAbstractModel
A commonly useful base for KGEMs using embeddings and interaction modules.
This model does not use post-init hooks to initialize all of its parameters automatically. Instead, Model.reset_parameters_() is called at the end of ERModel.__init__. This works because subclasses of ERModel must pass all of their trainable parameters through the call to super().__init__(). Subclasses can still run other code after calling super().__init__(), such as registering regularizers (as done in pykeen.models.ConvKB and pykeen.models.TransH).
Initialize the module.
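The initialization order described above can be illustrated with a small, dependency-free toy. `BaseModel`, `RegularizedModel`, and the `"lp-regularizer"` string are hypothetical stand-ins, not PyKEEN classes; the sketch only assumes the ordering stated above: parameters go through `super().__init__()`, the reset happens at the end of the base `__init__`, and extra setup such as regularizer registration follows.

```python
class BaseModel:
    """Hypothetical stand-in for a base class that resets parameters at the end of __init__."""

    def __init__(self, parameters):
        # Every trainable "parameter" must arrive here, before the reset below runs.
        self.parameters = dict(parameters)
        self.regularizers = []
        self.reset_log = []
        # Explicit call at the end of __init__ instead of a post-init hook.
        self.reset_parameters_()

    def reset_parameters_(self):
        # Record which parameters existed when the reset happened.
        self.reset_log.append(sorted(self.parameters))


class RegularizedModel(BaseModel):
    """Hypothetical subclass mirroring the pattern described for ConvKB / TransH."""

    def __init__(self):
        super().__init__(parameters={"entity": [0.0] * 4, "relation": [0.0] * 4})
        # Code after super().__init__() still runs, e.g. registering a (placeholder) regularizer.
        self.regularizers.append(("entity", "lp-regularizer"))


model = RegularizedModel()
print(model.reset_log)     # [['entity', 'relation']]
print(model.regularizers)  # [('entity', 'lp-regularizer')]
```

Because the reset runs exactly once and only sees parameters passed to the base constructor, anything a subclass creates later is not covered by it, which is why all trainable parameters must go through `super().__init__()`.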
- Parameters
triples_factory (CoreTriplesFactory) – The triples factory, which provides access to the dataset.
interaction (Union[str, Interaction[~HeadRepresentation, ~RelationRepresentation, ~TailRepresentation], Type[Interaction[~HeadRepresentation, ~RelationRepresentation, ~TailRepresentation]]]) – The interaction module (e.g., TransE).
interaction_kwargs (Optional[Mapping[str, Any]]) – Additional keyword arguments passed to the interaction module’s constructor, if it is not already instantiated.
entity_representations (Union[None, EmbeddingSpecification, RepresentationModule, Sequence[Union[EmbeddingSpecification, RepresentationModule]]]) – The entity representation, or a sequence of representations.
relation_representations (Union[None, EmbeddingSpecification, RepresentationModule, Sequence[Union[EmbeddingSpecification, RepresentationModule]]]) – The relation representation, or a sequence of representations.
loss (Union[str, Loss, Type[Loss], None]) – The loss to use. If None is given, the default loss of the model subclass is used.
loss_kwargs (Optional[Mapping[str, Any]]) – Additional keyword arguments passed to the loss module’s constructor, if it is not already instantiated.
predict_with_sigmoid (bool) – Whether to apply the sigmoid function to the scores at prediction time. Doing so may produce exactly equal scores for triples with very high or very low raw scores. If the model was not trained with a sigmoid (or with BCEWithLogitsLoss), its scores are not calibrated to work well with the sigmoid.
preferred_device (Union[str, device, None]) – The preferred device for model training and inference.
random_seed (Optional[int]) – A random seed used to initialize the model’s weights. Set it when aiming for reproducibility.
skip_checks (bool) – Whether to skip the entity representation checks.
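The tie effect mentioned for predict_with_sigmoid can be reproduced with plain Python floats (IEEE 754 double precision). The `sigmoid` helper below is written only for illustration; PyKEEN applies the sigmoid to tensors, and lower-precision types such as float32 saturate at smaller score magnitudes.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid on Python (float64) floats."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For very negative x, exp(x) underflows to 0.0 instead of exp(-x) overflowing.
    z = math.exp(x)
    return z / (1.0 + z)


# Distinct raw scores that collapse to identical probabilities after the sigmoid.
high_scores = [40.0, 50.0]
low_scores = [-800.0, -900.0]

print([sigmoid(s) for s in high_scores])  # [1.0, 1.0]
print([sigmoid(s) for s in low_scores])   # [0.0, 0.0]
```

For 40.0 and 50.0, `exp(-x)` is far below half the float64 machine epsilon, so `1.0 + exp(-x)` rounds to exactly 1.0 and the ordering between the two triples is lost.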
Methods Summary
append_weight_regularizer(parameter, regularizer) – Add a model weight to a regularizer’s weight list, and register the regularizer with the model.
forward(h_indices, r_indices, t_indices[, ...]) – Forward pass.
Methods Documentation
- append_weight_regularizer(parameter, regularizer)
Add a model weight to a regularizer’s weight list, and register the regularizer with the model.
- Parameters
parameter (Union[str, Parameter, Iterable[Union[str, Parameter]]]) – The parameter(s), each given either as a name or as an nn.Parameter object. The available parameter names can be listed with sorted(dict(self.named_parameters()).keys()).
regularizer (Regularizer) – The regularizer instance that will regularize the weights.
- Raises
KeyError – If an invalid parameter name was given.
- Return type
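The name-or-parameter handling described for append_weight_regularizer can be sketched without PyTorch. `Param`, `ToyRegularizer`, and `ToyModel` are hypothetical stand-ins (the dictionary `named` plays the role of `dict(self.named_parameters())`); this is a sketch of the documented behaviour, not PyKEEN's implementation.

```python
from typing import Dict, Iterable, List, Union


class Param:
    """Hypothetical stand-in for torch.nn.Parameter that only carries a name."""

    def __init__(self, name: str) -> None:
        self.name = name


class ToyRegularizer:
    """Hypothetical stand-in for a regularizer holding a list of weights."""

    def __init__(self) -> None:
        self.weights: List[Param] = []


class ToyModel:
    def __init__(self) -> None:
        # Plays the role of dict(self.named_parameters()).
        self.named: Dict[str, Param] = {
            name: Param(name) for name in ("entity.weight", "relation.weight")
        }
        self.weight_regularizers: List[ToyRegularizer] = []

    def append_weight_regularizer(
        self,
        parameter: Union[str, Param, Iterable[Union[str, Param]]],
        regularizer: ToyRegularizer,
    ) -> None:
        # Normalize a single name or parameter into a one-element list.
        if isinstance(parameter, (str, Param)):
            parameter = [parameter]
        resolved: List[Param] = []
        for p in parameter:
            if isinstance(p, str):
                if p not in self.named:
                    raise KeyError(f"Invalid parameter name: {p!r}. Available: {sorted(self.named)}")
                p = self.named[p]
            resolved.append(p)
        # Only modify state after every name resolved successfully.
        regularizer.weights.extend(resolved)
        # Register the regularizer once, however many weights it collects.
        if regularizer not in self.weight_regularizers:
            self.weight_regularizers.append(regularizer)


model = ToyModel()
reg = ToyRegularizer()
model.append_weight_regularizer("entity.weight", reg)
model.append_weight_regularizer([model.named["relation.weight"]], reg)
print([w.name for w in reg.weights])   # ['entity.weight', 'relation.weight']
print(len(model.weight_regularizers))  # 1
```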
- forward(h_indices, r_indices, t_indices, slice_size=None, slice_dim=None)
Forward pass.
This method takes head, relation, and tail indices and calculates the corresponding scores.
All indices that are not None must either contain a single element or share the same shape, which determines the batch size.
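The shape rule for the indices can be sketched as a small validator. `infer_batch_size` is a hypothetical helper that uses plain lists in place of LongTensor; returning 1 when no index has more than one element is an assumption of this sketch.

```python
from typing import Optional, Sequence


def infer_batch_size(
    h_indices: Optional[Sequence[int]] = None,
    r_indices: Optional[Sequence[int]] = None,
    t_indices: Optional[Sequence[int]] = None,
) -> int:
    """Return the batch size implied by the given indices (hypothetical helper)."""
    # None means "use all", so it does not constrain the batch size.
    sizes = {len(idx) for idx in (h_indices, r_indices, t_indices) if idx is not None}
    # Single-element indices are compatible with any batch size.
    sizes.discard(1)
    if len(sizes) > 1:
        raise ValueError(f"Incompatible index lengths: {sorted(sizes)}")
    # Assumption: with no multi-element index, the batch size is 1.
    return sizes.pop() if sizes else 1


# Three heads, one shared relation, all tails (t_indices=None).
print(infer_batch_size(h_indices=[0, 1, 2], r_indices=[5]))  # 3
```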
- Parameters
h_indices (Optional[LongTensor]) – The head indices. None indicates that all entities are used as heads.
r_indices (Optional[LongTensor]) – The relation indices. None indicates that all relations are used.
t_indices (Optional[LongTensor]) – The tail indices. None indicates that all entities are used as tails.
slice_dim (Optional[str]) – The dimension along which to slice. One of {“h”, “r”, “t”}.
- Return type
FloatTensor
- Returns
shape: (batch_size, num_heads, num_relations, num_tails). The score for each triple.
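The documentation above does not describe slice_size, so the following is only a generic sketch of what slicing along the tail dimension ("t") usually means: score all tails one chunk at a time and concatenate the chunks, which should match the unsliced result. `toy_scores` and `score_all_tails` are hypothetical, and embedding entity i as the scalar i is purely illustrative; this is not PyKEEN's implementation.

```python
from typing import Iterable, List, Optional


def toy_scores(h: int, r: int, tails: Iterable[int]) -> List[float]:
    # Hypothetical TransE-like score on 1-D embeddings, where entity i is embedded
    # as float(i) and relation r as float(r): score = -|h + r - t|.
    return [-abs(float(h) + float(r) - float(t)) for t in tails]


def score_all_tails(h: int, r: int, num_entities: int, slice_size: Optional[int] = None) -> List[float]:
    """Score (h, r, t) for every tail t, optionally one slice of tails at a time."""
    if slice_size is None:
        return toy_scores(h, r, range(num_entities))
    if slice_size <= 0:
        raise ValueError("slice_size must be positive")
    scores: List[float] = []
    for start in range(0, num_entities, slice_size):
        # Only `slice_size` tails are scored per step, bounding intermediate result size.
        stop = min(start + slice_size, num_entities)
        scores.extend(toy_scores(h, r, range(start, stop)))
    return scores


full = score_all_tails(h=2, r=3, num_entities=10)
sliced = score_all_tails(h=2, r=3, num_entities=10, slice_size=3)
print(full == sliced)  # True
```

Slicing trades a few extra loop iterations for a smaller peak memory footprint; the scores themselves should not change.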