PartiallyRestrictedPredictionDataset
- class PartiallyRestrictedPredictionDataset(*, heads: Tensor | Collection[int] | int | None = None, relations: Tensor | Collection[int] | int | None = None, tails: Tensor | Collection[int] | int | None = None, target: Literal['head', 'relation', 'tail'] = 'tail')
Bases: PredictionDataset
A dataset for scoring some links.
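“Some links” is defined as

\[\mathcal{T}_{interest} = \mathcal{E}_{h} \times \mathcal{R}_{r} \times \mathcal{E}_{t}\]

where \(\mathcal{E}_{h}\), \(\mathcal{R}_{r}\), and \(\mathcal{E}_{t}\) are the (possibly restricted) sets of head entities, relations, and tail entities.

Note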
For now, the target, i.e., the position whose prediction method in the model is used, must range over the full set of entities/relations; only the non-target positions can be restricted.
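To make the definition concrete, the following pure-Python sketch enumerates the set of interest for small, hypothetical ID sets. It assumes the tail is the target, so the tail set is the full entity range; the IDs and `num_entities` are made up for illustration and do not come from any dataset. Each (head, relation) pair acts as one query that the model scores against all tails at once.

```python
from itertools import product

# Hypothetical ID sets; not taken from any real dataset.
restricted_heads = [3, 7]          # E_h: restricted head entity IDs
restricted_relations = [0, 5, 9]   # R_r: restricted relation IDs
num_entities = 4                   # assumed size of the full entity set

# The tail is the prediction target, so it ranges over all entities.
all_tails = range(num_entities)

# T_interest = E_h x R_r x E_t as a list of (head, relation, tail) triples.
triples_of_interest = list(product(restricted_heads, restricted_relations, all_tails))

# Each (head, relation) pair is one query scored against every tail.
queries = list(product(restricted_heads, restricted_relations))

print(len(queries), len(triples_of_interest))  # 6 24
```

The size of the set is the product of the set sizes, here 2 · 3 · 4 = 24, which is why restricting the non-target positions can shrink the scoring work considerably.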
Example:

```python
# train model; note: needs a larger number of epochs to do something useful ;-)
from pykeen.pipeline import pipeline

result = pipeline(dataset="nations", model="mure", training_kwargs=dict(num_epochs=0))

# create a prediction dataset where the head entities are from a set of European countries,
# and the relations are connected to tourism; target defaults to "tail", so tails stay unrestricted
from pykeen.predict import PartiallyRestrictedPredictionDataset

heads = result.training.entities_to_ids(entities=["netherlands", "poland", "uk"])
relations = result.training.relations_to_ids(relations=["reltourism", "tourism", "tourism3"])
dataset = PartiallyRestrictedPredictionDataset(heads=heads, relations=relations)

# calculate all scores for this restricted set, and keep the k=3 largest
from pykeen.predict import consume_scores, TopKScoreConsumer

consumer = TopKScoreConsumer(k=3)
consume_scores(result.model, dataset, consumer)
score_pack = consumer.finalize()

# add labels
df = result.training.tensor_to_df(score_pack.result, score=score_pack.scores)
```
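Conceptually, a top-k consumer only has to remember the k best-scoring triples seen so far while batches of scores stream in. The sketch below illustrates that idea with `heapq` on plain Python lists. It is not PyKEEN's `TopKScoreConsumer` implementation; the helper name `top_k_triples` and the batch format (lists of `(triple, score)` pairs) are hypothetical.

```python
import heapq
from typing import Iterable, List, Tuple

Triple = Tuple[int, int, int]


def top_k_triples(
    batches: Iterable[List[Tuple[Triple, float]]], k: int
) -> List[Tuple[Triple, float]]:
    """Return the k highest-scoring (triple, score) pairs, best first.

    Hypothetical helper: each batch is a list of (triple, score) pairs.
    """
    heap: List[Tuple[float, Triple]] = []  # min-heap keyed on score
    for batch in batches:
        for triple, score in batch:
            if len(heap) < k:
                heapq.heappush(heap, (score, triple))
            elif score > heap[0][0]:
                # Evict the current worst of the kept k entries.
                heapq.heapreplace(heap, (score, triple))
    # Sort descending by score for the final result.
    return [(triple, score) for score, triple in sorted(heap, reverse=True)]


batches = [
    [((0, 1, 2), 0.3), ((0, 1, 3), 0.9)],
    [((1, 1, 2), 0.5), ((1, 1, 3), 0.1), ((1, 2, 0), 0.7)],
]
print(top_k_triples(batches, k=3))
# [((0, 1, 3), 0.9), ((1, 2, 0), 0.7), ((1, 1, 2), 0.5)]
```

A min-heap of size k keeps memory bounded by k regardless of how many triples are in the restricted set, which is the point of consuming scores in a streaming fashion rather than materializing all of them.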
Initialize the restricted prediction dataset.
- Parameters:
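heads (Tensor | Collection[int] | int | None) – the restricted head entity IDs; must be None if target is 'head', and must be given otherwise
relations (Tensor | Collection[int] | int | None) – the restricted relation IDs; must be None if target is 'relation', and must be given otherwise
tails (Tensor | Collection[int] | int | None) – the restricted tail entity IDs; must be None if target is 'tail', and must be given otherwise
target (Literal['head', 'relation', 'tail']) – the prediction target, i.e., the position that is scored against the full set of entities/relations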
- Raises:
NotImplementedError – if the target position is restricted, or any non-target position is not restricted
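The parameter rules and the `NotImplementedError` condition above can be summarized in a small validation function. This is a hypothetical sketch of the documented behavior using plain Python lists instead of tensors; the name `normalize_restrictions`, its return format, and the treatment of a single `int` as a one-element restriction are assumptions, not part of PyKEEN's API.

```python
from typing import Collection, Dict, List, Union

Restriction = Union[Collection[int], int, None]
POSITIONS = ("head", "relation", "tail")


def normalize_restrictions(
    heads: Restriction = None,
    relations: Restriction = None,
    tails: Restriction = None,
    target: str = "tail",
) -> Dict[str, List[int]]:
    """Validate restrictions and return the non-target ID lists (hypothetical helper)."""
    if target not in POSITIONS:
        raise ValueError(f"invalid target: {target!r}")
    result: Dict[str, List[int]] = {}
    for position, restriction in zip(POSITIONS, (heads, relations, tails)):
        if position == target:
            # The target must range over the full set, so it may not be restricted.
            if restriction is not None:
                raise NotImplementedError(f"cannot restrict the target position {position!r}")
            continue
        # Every non-target position must be restricted explicitly.
        if restriction is None:
            raise NotImplementedError(f"non-target position {position!r} must be restricted")
        # Assumption: a single int is shorthand for a one-element restriction.
        if isinstance(restriction, int):
            restriction = [restriction]
        result[position] = [int(i) for i in restriction]
    return result


print(normalize_restrictions(heads=[1, 2], relations=4))
# {'head': [1, 2], 'relation': [4]}
```

With the default target of 'tail', passing only `heads` and `relations` is the valid combination, which matches the usage in the example above.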