LCWATrainingLoop

class LCWATrainingLoop(*, target=None, **kwargs)

Bases: TrainingLoop[Tuple[LongTensor, FloatTensor], Tuple[LongTensor, FloatTensor]]

A training loop that is based upon the local closed world assumption (LCWA).

Under the LCWA, for a given true training triple \((h, r, t) \in \mathcal{T}_{train}\), all triples \((h, r, t') \notin \mathcal{T}_{train}\) are assumed to be false. The training approach thus uses 1-n scoring: for each (head, relation) pair, it efficiently computes the scores of all triples \((h, r, t')\) with \(t' \in \mathcal{E}\) at once.
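As an illustration, here is a minimal pure-Python sketch of LCWA-style 1-n scoring for tail prediction. It assumes a toy DistMult-style scoring function and a binary cross-entropy loss; the embeddings and the helpers `score_all_tails`, `lcwa_labels`, and `lcwa_bce_loss` are hypothetical and do not reproduce PyKEEN's implementation. For one (head, relation) pair, every entity receives a score, and every tail not observed with that pair in training is labeled 0.

```python
import math

# Hypothetical toy embeddings: 4 entities and 2 relations, dimension 2.
ENTITY_EMB = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
RELATION_EMB = [[1.0, 1.0], [0.5, -1.0]]


def score_all_tails(h: int, r: int) -> list[float]:
    """1-n scoring: a DistMult-style score <e_h, w_r, e_t'> for every candidate tail t'."""
    hr = [a * b for a, b in zip(ENTITY_EMB[h], RELATION_EMB[r])]
    return [sum(x * y for x, y in zip(hr, e_t)) for e_t in ENTITY_EMB]


def lcwa_labels(train: set[tuple[int, int, int]], h: int, r: int, num_entities: int) -> list[float]:
    """LCWA: (h, r, t') is positive iff it occurs in training; otherwise it is assumed false."""
    return [1.0 if (h, r, t) in train else 0.0 for t in range(num_entities)]


def lcwa_bce_loss(scores: list[float], labels: list[float]) -> float:
    """Mean binary cross-entropy over all candidate tails, using sigmoid(score) as probability."""
    total = 0.0
    for s, y in zip(scores, labels):
        p = 1.0 / (1.0 + math.exp(-s))
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total / len(scores)


train = {(0, 0, 1), (0, 0, 2), (1, 1, 3)}
scores = score_all_tails(0, 0)
labels = lcwa_labels(train, 0, 0, num_entities=4)
print(labels)  # [0.0, 1.0, 1.0, 0.0]
print(lcwa_bce_loss(scores, labels))
```

The sketch builds labels per pair for readability; a practical implementation would typically precompute them once for all (head, relation) pairs in the training set and score whole batches of pairs at a time.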

This implementation slightly generalizes the original LCWA by allowing the same assumption to be made for the relation or the head entity instead of the tail. The relation case, i.e., predicting the relation for a given (head, tail) pair, is commonly encountered in visual relation prediction.
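The generalization can be pictured as grouping training triples by the two columns that are not the target and collecting a multi-hot label vector over the target column. The following is a sketch under that assumption; `lcwa_instances` is a hypothetical helper, not part of the documented API. Note that for the relation target the label vector ranges over relations rather than entities.

```python
from collections import defaultdict


def lcwa_instances(
    triples: list[tuple[int, int, int]],
    target: int,
    num_entities: int,
    num_relations: int,
) -> dict[tuple[int, int], list[float]]:
    """Map each observed pair of non-target columns to a multi-hot label vector over the target column."""
    if target not in (0, 1, 2):
        raise ValueError(f"Invalid target column: {target}")
    # Heads (column 0) and tails (column 2) range over entities; relations (column 1) have their own vocabulary.
    num_targets = num_relations if target == 1 else num_entities
    context_cols = [c for c in (0, 1, 2) if c != target]
    labels: dict[tuple[int, int], list[float]] = defaultdict(lambda: [0.0] * num_targets)
    for triple in triples:
        key = (triple[context_cols[0]], triple[context_cols[1]])
        # Observed target values are positives; all others stay 0.0 (assumed false under the LCWA).
        labels[key][triple[target]] = 1.0
    return dict(labels)


triples = [(0, 0, 1), (0, 0, 2), (0, 1, 2), (3, 1, 2)]
tail_labels = lcwa_instances(triples, target=2, num_entities=4, num_relations=2)
relation_labels = lcwa_instances(triples, target=1, num_entities=4, num_relations=2)
print(relation_labels[(0, 2)])  # [1.0, 1.0]
```

With `target=1`, the keys are (head, tail) pairs, which matches the visual relation prediction setting described above.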

[ruffinelli2020] refer to the LCWA as KvsAll in their work.

Initialize the training loop.

Parameters
  • target (Union[None, int, str]) – The target column: one of {0, 1, 2} for head, relation, or tail prediction, respectively. Defaults to 2, i.e., tail prediction.

  • kwargs – Additional keyword-based parameters passed to TrainingLoop.__init__.

Raises

ValueError – If an invalid target column is given.
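The documented contract for `target` (default 2, valid columns {0, 1, 2}, `ValueError` otherwise) can be sketched as a small normalization helper. This is an assumption-laden illustration: `normalize_target` is hypothetical, and the string aliases "head", "relation", and "tail" are assumed, since the parameter description only states that a `str` is accepted.

```python
from typing import Optional, Union

# Assumed string aliases; the documentation only says that ``str`` is an accepted type.
TARGET_NAMES = {"head": 0, "relation": 1, "tail": 2}


def normalize_target(target: Optional[Union[int, str]] = None) -> int:
    """Resolve the ``target`` argument to a column index in {0, 1, 2}."""
    if target is None:
        return 2  # default: tail prediction
    if isinstance(target, str):
        if target not in TARGET_NAMES:
            raise ValueError(f"Invalid target column: {target!r}")
        return TARGET_NAMES[target]
    if target not in (0, 1, 2):
        raise ValueError(f"Invalid target column: {target!r}")
    return target


print(normalize_target())            # 2
print(normalize_target("relation"))  # 1
```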