_NewAbstractModel

class _NewAbstractModel(*, triples_factory, loss=None, loss_kwargs=None, predict_with_sigmoid=False, random_seed=None)

Bases: pykeen.models.base.Model, abc.ABC

An abstract class for knowledge graph embedding models (KGEMs).

The only method a subclass needs to implement is Model.forward(). Unlike the completely general torch.nn.Module.forward(), Model.forward() takes the indices of the head, relation, and tail representation(s) and determines a score.

Subclasses of Model are free to decide how entity and relation representations are stored, how they are looked up, and how they are scored. ERModel provides a commonly useful implementation that allows specifying one or more entity representations and one or more relation representations in the form of pykeen.nn.Embedding, together with a matching instance of pykeen.nn.Interaction.
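The contract described above (indices in, score out, with storage and lookup left to the subclass) can be illustrated with a small standard-library sketch. It deliberately does not use PyKEEN or torch: the class ToyTranslationalModel, its constructor arguments, and the single-triple forward(h_index, r_index, t_index) signature are hypothetical simplifications, and the translational distance is just one possible scoring choice, not the library's implementation.

```python
import math
from typing import List, Sequence

Vector = List[float]


class ToyTranslationalModel:
    """Hypothetical stand-in for a KGEM: stores vectors, looks them up by index, scores a triple."""

    def __init__(self, entity_vectors: Sequence[Vector], relation_vectors: Sequence[Vector]):
        # How representations are stored is up to the model; here, plain lists.
        self.entity_vectors = [list(v) for v in entity_vectors]
        self.relation_vectors = [list(v) for v in relation_vectors]

    def forward(self, h_index: int, r_index: int, t_index: int) -> float:
        """Map (head, relation, tail) indices to one real-valued score."""
        h = self.entity_vectors[h_index]
        r = self.relation_vectors[r_index]
        t = self.entity_vectors[t_index]
        # Translational scoring (an assumption for this sketch):
        # the closer h + r is to t, the higher (less negative) the score.
        distance = math.sqrt(sum((hi + ri - ti) ** 2 for hi, ri, ti in zip(h, r, t)))
        return -distance


model = ToyTranslationalModel(
    entity_vectors=[[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]],
    relation_vectors=[[1.0, 0.0]],
)
good = model.forward(0, 0, 1)  # h + r lands exactly on t
bad = model.forward(0, 0, 2)   # h + r is far from t
print(good > bad)
```

A real subclass would typically operate on batches of index tensors rather than on single integers; the sketch only shows the shape of the responsibility.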

Initialize the module.

Parameters
  • triples_factory (KGInfo) – The triples factory facilitates access to the dataset.

  • loss (Union[str, Loss, Type[Loss], None]) – The loss to use. If None is given, use the loss default specific to the model subclass.

  • predict_with_sigmoid (bool) – Whether to apply a sigmoid to the scores at prediction time. Doing so may yield exactly equal scores for triples whose raw scores are very high or very low. If the model was not trained with a sigmoid applied (or with BCEWithLogitsLoss), the scores are not calibrated to perform well with sigmoid.

  • random_seed (Optional[int]) – A random seed to use for initialising the model’s weights. Should be set when aiming at reproducibility.
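The caveat on predict_with_sigmoid about exactly equal scores follows from floating-point saturation. The sketch below uses only the Python standard library and a hypothetical helper named sigmoid; the raw scores are made-up values chosen to show that distinct raw scores can collapse to the same value once squashed.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function for Python floats (float64)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Made-up raw scores: each pair is clearly ordered before the sigmoid.
raw_high = [40.0, 50.0]
raw_low = [-800.0, -900.0]

squashed_high = [sigmoid(s) for s in raw_high]
squashed_low = [sigmoid(s) for s in raw_low]

# After squashing, the ordering information within each pair is lost.
ties_high = squashed_high[0] == squashed_high[1]
ties_low = squashed_low[0] == squashed_low[1]
print(ties_high, ties_low)
```

Lower-precision floating-point types saturate at smaller magnitudes, so ties of this kind appear sooner there, which matters when a ranking depends on strict score differences.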

Attributes Summary

  • can_slice_h

  • can_slice_r

  • can_slice_t

  • regularizer_default – The default regularizer class

  • regularizer_default_kwargs – The default parameters for the default regularizer class

Methods Summary

  • collect_regularization_term() – Get the regularization term for the loss function.

  • post_parameter_update() – Has to be called after each parameter update.

Attributes Documentation

can_slice_h: ClassVar[bool] = True
can_slice_r: ClassVar[bool] = True
can_slice_t: ClassVar[bool] = True
regularizer_default: ClassVar[Optional[Type[Regularizer]]] = None

The default regularizer class

regularizer_default_kwargs: ClassVar[Optional[Mapping[str, Any]]] = None

The default parameters for the default regularizer class

Methods Documentation

collect_regularization_term()

Get the regularization term for the loss function.
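As a rough illustration of how a regularization term might enter the loss, the following standard-library sketch adds a weighted squared L2 penalty to a placeholder data loss. ToyRegularizedModel, penalty_weight, and the choice of penalty are assumptions made for this example; only the method name is taken from the documentation above, and the library's actual computation may differ.

```python
from typing import List


class ToyRegularizedModel:
    """Hypothetical model whose regularization term is a weighted squared L2 norm."""

    def __init__(self, weights: List[float], penalty_weight: float = 0.1):
        self.weights = list(weights)
        self.penalty_weight = penalty_weight

    def collect_regularization_term(self) -> float:
        # Assumed penalty for this sketch: penalty_weight * sum of squared parameters.
        return self.penalty_weight * sum(w * w for w in self.weights)


toy = ToyRegularizedModel(weights=[3.0, 4.0], penalty_weight=0.5)
data_loss = 2.0  # placeholder for the loss computed from triple scores
total_loss = data_loss + toy.collect_regularization_term()
print(total_loss)
```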

post_parameter_update()

Has to be called after each parameter update.

Return type

None
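The requirement to call post_parameter_update() after each parameter update can be pictured with a toy training loop. One plausible use of such a hook, assumed here rather than stated above, is re-applying a constraint that a gradient step may have violated. ToyConstrainedModel, sgd_step, the unit-norm constraint, and the gradient values are all hypothetical, and the sketch uses only the standard library.

```python
import math
from typing import List


class ToyConstrainedModel:
    """Hypothetical model whose single embedding should stay at unit L2 norm."""

    def __init__(self, embedding: List[float]):
        self.embedding = list(embedding)

    def post_parameter_update(self) -> None:
        # Assumed constraint for this sketch: rescale back to unit length.
        norm = math.sqrt(sum(x * x for x in self.embedding))
        if norm > 0:
            self.embedding = [x / norm for x in self.embedding]


def sgd_step(model: ToyConstrainedModel, gradient: List[float], lr: float) -> None:
    """Plain gradient step; it knows nothing about the constraint."""
    model.embedding = [x - lr * g for x, g in zip(model.embedding, gradient)]


toy = ToyConstrainedModel([1.0, 0.0])
for _ in range(3):
    sgd_step(toy, gradient=[0.0, -1.0], lr=0.5)  # made-up gradient
    toy.post_parameter_update()                  # called after every update

norm_after = math.sqrt(sum(x * x for x in toy.embedding))
print(norm_after)
```

Skipping the hook in this sketch would let the norm grow with every step, which is the kind of drift the call is meant to prevent.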