get_head_prediction_df
- get_head_prediction_df(model, triples_factory, relation_label, tail_label, *, heads=None, **kwargs)
Predict heads for the given relation and tail, both specified by their string labels.
- Parameters
  - model (Model) – a PyKEEN model
  - triples_factory (TriplesFactory) – the training triples factory
  - relation_label (str) – the string label for the relation
  - tail_label (str) – the string label for the tail entity
  - heads (Optional[Sequence[str]]) – restrict head prediction to the given entities
  - kwargs – additional keyword arguments passed to get_prediction_df().
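Before any scoring happens, the parameters above imply two steps: turning the string labels into integer IDs, and deciding which entities are candidate heads (all of them, or only those listed in `heads`). The following is a minimal sketch of that logic under stated assumptions: `entity_to_id` and `relation_to_id` are plain dicts standing in for the label-to-ID mappings kept by the triples factory, and `resolve_head_query` is a hypothetical helper for illustration, not part of PyKEEN's API or its actual implementation.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def resolve_head_query(
    entity_to_id: Dict[str, int],
    relation_to_id: Dict[str, int],
    relation_label: str,
    tail_label: str,
    heads: Optional[Sequence[str]] = None,
) -> Tuple[int, int, List[int]]:
    """Hypothetical helper: map labels to IDs and choose candidate head IDs."""
    # Unknown labels raise KeyError instead of being silently ignored.
    relation_id = relation_to_id[relation_label]
    tail_id = entity_to_id[tail_label]
    if heads is None:
        # No restriction: every known entity is a candidate head.
        candidate_ids = sorted(entity_to_id.values())
    else:
        # Restrict prediction to the requested head labels, in the given order.
        candidate_ids = [entity_to_id[label] for label in heads]
    return relation_id, tail_id, candidate_ids


# Usage with a toy vocabulary (assumed, not taken from the Nations dataset files).
entity_to_id = {"brazil": 0, "china": 1, "usa": 2}
relation_to_id = {"accusation": 0}
all_heads = resolve_head_query(entity_to_id, relation_to_id, "accusation", "brazil")
some_heads = resolve_head_query(
    entity_to_id, relation_to_id, "accusation", "brazil", heads=["usa", "china"]
)
```

Here `all_heads` is `(0, 0, [0, 1, 2])` and `some_heads` is `(0, 0, [2, 1])`, so only the restricted candidates would be scored in the second case.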
- Return type
DataFrame
- Returns
A dataframe of head predictions with shape (k, 3). It contains either the k highest-scoring triples or, if k is None, all possible triples.
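To make the return value concrete, here is a minimal sketch of the ranking step, assuming a score has already been computed for every candidate head (for example, by the trained model) and assuming the three columns are head ID, head label and score. The column layout is an assumption, `rank_heads` is a hypothetical helper rather than PyKEEN's implementation, and the sketch returns a list of tuples instead of a pandas DataFrame to stay dependency-free.

```python
from typing import Dict, List, Optional, Tuple

# One output row: (head_id, head_label, score). Assumed column layout.
Row = Tuple[int, str, float]


def rank_heads(
    scores: Dict[int, float],
    id_to_entity: Dict[int, str],
    k: Optional[int] = None,
) -> List[Row]:
    """Hypothetical helper: sort candidate heads by score and keep the top k."""
    rows = [(head_id, id_to_entity[head_id], score) for head_id, score in scores.items()]
    # Higher scores mean more plausible triples, so sort in descending order;
    # ties are broken by head ID to keep the output deterministic.
    rows.sort(key=lambda row: (-row[2], row[0]))
    # k=None keeps every candidate, mirroring "all possible triples if k is None".
    return rows if k is None else rows[:k]


# Usage with made-up scores for three candidate heads.
scores = {0: -1.5, 1: 0.25, 2: -0.1}
id_to_entity = {0: "brazil", 1: "china", 2: "usa"}
top_two = rank_heads(scores, id_to_entity, k=2)
everything = rank_heads(scores, id_to_entity)
```

With these inputs, `top_two` is `[(1, "china", 0.25), (2, "usa", -0.1)]`, a (2, 3) result, while `everything` keeps all three rows.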
The following example trains a RotatE model on the Nations dataset and then scores all entities as candidate heads for a given relation and tail entity.
>>> from pykeen.pipeline import pipeline
>>> from pykeen.models.predict import get_head_prediction_df
>>> result = pipeline(
...     dataset='Nations',
...     model='RotatE',
... )
>>> df = get_head_prediction_df(
...     result.model,
...     triples_factory=result.training,
...     relation_label='accusation',
...     tail_label='brazil',
... )