get_head_prediction_df
- get_head_prediction_df(model, relation_label, tail_label, *, triples_factory, add_novelties=True, remove_known=False, testing=None, mode=None)
Predict heads for the given relation and tail (given by label).
- Parameters
model (Model) – A PyKEEN model.
relation_label (str) – The string label for the relation.
tail_label (str) – The string label for the tail entity.
triples_factory (TriplesFactory) – The training triples factory.
add_novelties (bool) – Should the dataframe include a column denoting whether the ranked head entities correspond to novel triples?
remove_known (bool) – Should non-novel triples (those appearing in the training set) be removed from the results? On the one hand, keeping them lets you better assess the quality of the predictions: you want to see that the non-novel triples generally have higher scores. On the other hand, if you are doing hypothesis generation, they may be a distraction. If this is set to True, non-novel triples are removed and the novelty column is excluded, since all remaining triples are novel. Defaults to False.
testing (Optional[LongTensor]) – The mapped_triples from the testing triples factory (TriplesFactory.mapped_triples).
mode (Optional[Literal['training', 'validation', 'testing']]) – The pass mode, which is None in the transductive setting and one of 'training', 'validation', or 'testing' in the inductive setting.
- Return type
DataFrame
- Returns
A dataframe with one row per scored head entity and columns determined by the settings above (for example, whether the novelty column is included).
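The interaction between add_novelties and remove_known can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not PyKEEN's implementation: the helper build_head_predictions, its arguments, and the column names head_id, head_label, score, and in_training are hypothetical; scores are assumed to be precomputed for every entity (indexed by entity ID); and the testing triples are left out for brevity.

```python
from typing import Any, Dict, List, Set, Tuple

Triple = Tuple[int, int, int]  # (head_id, relation_id, tail_id)


def build_head_predictions(
    scores: List[float],
    id_to_label: Dict[int, str],
    relation_id: int,
    tail_id: int,
    training: Set[Triple],
    add_novelties: bool = True,
    remove_known: bool = False,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: assemble ranked head predictions.

    scores[i] is assumed to be the model's score for head entity i.
    """
    rows: List[Dict[str, Any]] = []
    for head_id, score in enumerate(scores):
        known = (head_id, relation_id, tail_id) in training
        if remove_known and known:
            continue  # non-novel triple: drop it from the results
        row: Dict[str, Any] = {
            "head_id": head_id,
            "head_label": id_to_label[head_id],
            "score": score,
        }
        # The novelty column is redundant once known triples are removed.
        if add_novelties and not remove_known:
            row["in_training"] = known
        rows.append(row)
    rows.sort(key=lambda r: r["score"], reverse=True)  # best-scoring heads first
    return rows


labels = {0: "a", 1: "b", 2: "c"}
training_triples = {(1, 0, 2)}  # hypothetical: (b, r0, c) is a known triple
all_rows = build_head_predictions([0.2, 0.9, 0.5], labels, 0, 2, training_triples)
novel_only = build_head_predictions(
    [0.2, 0.9, 0.5], labels, 0, 2, training_triples, remove_known=True
)
```

With remove_known=True, the known triple (b, r0, c) disappears and no remaining row carries the in_training flag, which mirrors the behavior described for remove_known.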
The following example shows how, after training a model on the Nations dataset, you can score all entities as candidate heads with respect to a given relation and tail entity.
>>> from pykeen.pipeline import pipeline
>>> from pykeen.models.predict import get_head_prediction_df
>>> result = pipeline(
...     dataset='Nations',
...     model='RotatE',
... )
>>> df = get_head_prediction_df(result.model, 'accusation', 'brazil', triples_factory=result.training)
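Conceptually, head prediction treats every entity as a candidate head h for the fixed relation r and tail t and scores each triple (h, r, t). The stdlib-only sketch below illustrates this for the RotatE formulation used in the example, where a relation is an element-wise rotation in complex space and the score is the negative distance -||h ∘ r - t||. The toy embeddings, the entity labels, and the helper names rotate_score and score_all_heads are made up for illustration; this is not PyKEEN's implementation, which may differ in details such as parametrization and normalization.

```python
import cmath
import math
from typing import Dict, List, Tuple


def rotate_score(head: List[complex], phases: List[float], tail: List[complex]) -> float:
    """Toy RotatE score: negative L2 distance between the rotated head and the tail."""
    squared = 0.0
    for h_i, theta_i, t_i in zip(head, phases, tail):
        r_i = cmath.exp(1j * theta_i)  # unit-modulus rotation e^{i*theta}
        squared += abs(h_i * r_i - t_i) ** 2
    return -math.sqrt(squared)


def score_all_heads(
    entity_embeddings: Dict[str, List[complex]],
    relation_phases: List[float],
    tail_label: str,
) -> List[Tuple[str, float]]:
    """Score every entity as a candidate head for (?, r, tail), best first."""
    tail = entity_embeddings[tail_label]
    scored = [
        (label, rotate_score(embedding, relation_phases, tail))
        for label, embedding in entity_embeddings.items()
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # higher score = more plausible
    return scored


# Hypothetical 1-dimensional embeddings; the relation rotates by 90 degrees.
toy_entities = {"e0": [1 + 0j], "e1": [-1j], "e2": [1j]}
toy_relation = [math.pi / 2]
ranking = score_all_heads(toy_entities, toy_relation, tail_label="e0")
```

Here e1 ranks first because rotating -i by 90 degrees lands exactly on the tail embedding 1, giving a distance of (numerically) zero.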