ClassificationEvaluator

class ClassificationEvaluator(**kwargs)

Bases: pykeen.evaluation.evaluator.Evaluator

An evaluator that uses classification metrics.
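Classification metrics treat every (triple, candidate entity) score as a binary prediction. A score is labeled positive if the entity completes a true triple and negative otherwise. The sketch below shows that view with a ROC-AUC computed by pairwise comparison. It is a plain-Python illustration of the kind of metric involved, not PyKEEN's implementation, and the function name `roc_auc` is hypothetical.

```python
from typing import Sequence


def roc_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """ROC-AUC as the probability that a random positive outscores a random negative.

    Ties count as half a win. Runs in O(P * N), so it only suits small illustrations.
    """
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# One row of scores over 4 candidate entities; entities 0 and 2 are true.
scores = [0.9, 0.2, 0.7, 0.4]
labels = [1, 0, 1, 0]
print(roc_auc(scores, labels))  # 1.0: every true entity outscores every false one
```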

Initialize the evaluator.

Parameters
  • filtered – Should filtered evaluation be performed?

  • requires_positive_mask – Does the evaluator need access to the masks?

  • batch_size – Evaluation batch size; must be > 0.

  • slice_size – The divisor for the scoring function when using slicing; must be > 0.

  • automatic_memory_optimization – Whether to automatically optimize the sub-batch size during evaluation with regard to the hardware at hand.
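The batch_size, slice_size and automatic_memory_optimization parameters can be pictured as nested chunking. Triples are scored batch_size at a time, and within a batch the candidate entities may be scored in slices. The plain-Python sketch below rests on stated assumptions. It treats slice_size as the number of entities per slice. It uses a hypothetical scoring callback. It models automatic memory optimization as halving the batch size whenever a `MemoryError` is raised. The names `score_all_entities`, `score_with_fallback` and `toy_score` are illustrative and do not belong to PyKEEN.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Triple = Tuple[int, int, int]
# Hypothetical scoring callback: (triples, entity ids) -> one score row per triple.
ScoreFn = Callable[[Sequence[Triple], Sequence[int]], List[List[float]]]


def score_all_entities(
    triples: Sequence[Triple],
    num_entities: int,
    score_fn: ScoreFn,
    batch_size: int,
    slice_size: Optional[int] = None,
) -> List[List[float]]:
    """Score each triple against every entity, batch_size triples at a time.

    Assumption: slice_size is the number of candidate entities scored per call.
    """
    if batch_size <= 0 or (slice_size is not None and slice_size <= 0):
        raise ValueError("batch_size and slice_size must be > 0")
    entity_ids = list(range(num_entities))
    step = slice_size or max(num_entities, 1)  # no slicing: all entities in one call
    rows: List[List[float]] = []
    for start in range(0, len(triples), batch_size):
        batch = triples[start:start + batch_size]
        batch_rows: List[List[float]] = [[] for _ in batch]
        for lo in range(0, num_entities, step):
            # Score one slice of entities and append it to each triple's row.
            for row, part in zip(batch_rows, score_fn(batch, entity_ids[lo:lo + step])):
                row.extend(part)
        rows.extend(batch_rows)
    return rows


def score_with_fallback(
    triples: Sequence[Triple],
    num_entities: int,
    score_fn: ScoreFn,
    batch_size: int,
    slice_size: Optional[int] = None,
) -> List[List[float]]:
    """Assumed stand-in for automatic memory optimization: halve batch_size on MemoryError."""
    while True:
        try:
            return score_all_entities(triples, num_entities, score_fn, batch_size, slice_size)
        except MemoryError:
            if batch_size == 1:
                raise
            batch_size //= 2


def toy_score(batch: Sequence[Triple], entity_ids: Sequence[int]) -> List[List[float]]:
    # Hypothetical score: entity ids closer to the triple's tail score higher.
    return [[float(-abs(e - t)) for e in entity_ids] for (_, _, t) in batch]


triples = [(0, 0, 1), (2, 0, 3), (1, 0, 0)]
print(score_with_fallback(triples, 4, toy_score, batch_size=2, slice_size=3))
# [[-1.0, 0.0, -1.0, -2.0], [-3.0, -2.0, -1.0, 0.0], [0.0, -1.0, -2.0, -3.0]]
```

The batch size and slice size change only how much work is done per call, not the resulting score matrix. This is why the fallback can shrink the batch size without affecting the evaluation outcome.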

Methods Summary

finalize()

Compute the final results, and clear buffers.

process_head_scores_(hrt_batch, true_scores, ...)

Process a batch of triples with their computed head scores for all entities.

process_tail_scores_(hrt_batch, true_scores, ...)

Process a batch of triples with their computed tail scores for all entities.

Methods Documentation

finalize()

Compute the final results, and clear buffers.

Return type

ClassificationMetricResults
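The "compute the final results, and clear buffers" contract suggests an accumulate-then-finalize pattern. Each batch appends flattened scores and 0/1 labels to buffers. finalize() then turns those buffers into metrics and resets them so the evaluator can be reused. The sketch below is a hypothetical plain-Python stand-in. The class `ToyClassificationEvaluator`, its `add_batch` method and the choice of average precision as the reported metric are assumptions. The real method returns a ClassificationMetricResults object.

```python
from typing import Dict, List, Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Mean of precision@k taken at the rank k of each positive (ties keep input order)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, total = 0, 0.0
    for k, i in enumerate(order, start=1):
        if labels[i] == 1:
            hits += 1
            total += hits / k
    if hits == 0:
        raise ValueError("no positive labels")
    return total / hits


class ToyClassificationEvaluator:
    """Hypothetical accumulator mirroring the buffer-then-finalize pattern."""

    def __init__(self) -> None:
        self.all_scores: List[float] = []
        self.all_labels: List[int] = []

    def add_batch(self, scores: Sequence[Sequence[float]], mask: Sequence[Sequence[int]]) -> None:
        # Flatten (batch_size, num_entities) rows into one long list of binary examples.
        for score_row, mask_row in zip(scores, mask):
            self.all_scores.extend(score_row)
            self.all_labels.extend(mask_row)

    def finalize(self) -> Dict[str, float]:
        result = {"average_precision": average_precision(self.all_scores, self.all_labels)}
        # Clear buffers so a second evaluation run starts from scratch.
        self.all_scores.clear()
        self.all_labels.clear()
        return result


ev = ToyClassificationEvaluator()
ev.add_batch([[0.9, 0.1, 0.8]], [[1, 0, 0]])
ev.add_batch([[0.2, 0.7, 0.3]], [[0, 1, 0]])
print(ev.finalize())  # average precision 5/6, about 0.833
```

Clearing inside finalize() means results from one evaluation run cannot leak into the next, at the cost of not being able to call finalize() twice on the same data.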

process_head_scores_(hrt_batch, true_scores, scores, dense_positive_mask=None)

Process a batch of triples with their computed head scores for all entities.

Parameters
  • hrt_batch (LongTensor) – shape: (batch_size, 3)

  • true_scores (FloatTensor) – shape: (batch_size)

  • scores (FloatTensor) – shape: (batch_size, num_entities)

  • dense_positive_mask (Optional[FloatTensor]) – shape: (batch_size, num_entities). An optional binary (0/1) tensor indicating other true entities.

Return type

None
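For head prediction, a dense_positive_mask can be built from a set of known true triples. Row i marks every entity h for which (h, r_i, t_i) is known to be true. The sketch below uses nested Python lists in place of tensors. The function name `head_positive_mask` and the set-of-tuples representation are assumptions for illustration. The sketch also marks the evaluated triple's own head as positive, since in a classification view it is a positive example as well.

```python
from typing import List, Sequence, Set, Tuple

Triple = Tuple[int, int, int]


def head_positive_mask(
    hrt_batch: Sequence[Triple],
    known_triples: Set[Triple],
    num_entities: int,
) -> List[List[int]]:
    """Row i marks every head h such that (h, r_i, t_i) is a known true triple."""
    mask = [[0] * num_entities for _ in hrt_batch]
    for i, (_, r, t) in enumerate(hrt_batch):
        for h in range(num_entities):
            if (h, r, t) in known_triples:
                mask[i][h] = 1
    return mask


known = {(0, 0, 2), (1, 0, 2), (3, 1, 2)}
batch = [(0, 0, 2), (3, 1, 2)]
print(head_positive_mask(batch, known, 4))  # [[1, 1, 0, 0], [0, 0, 0, 1]]
```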

process_tail_scores_(hrt_batch, true_scores, scores, dense_positive_mask=None)

Process a batch of triples with their computed tail scores for all entities.

Parameters
  • hrt_batch (LongTensor) – shape: (batch_size, 3)

  • true_scores (FloatTensor) – shape: (batch_size)

  • scores (FloatTensor) – shape: (batch_size, num_entities)

  • dense_positive_mask (Optional[FloatTensor]) – shape: (batch_size, num_entities). An optional binary (0/1) tensor indicating other true entities.

Return type

None
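process_head_scores_ and process_tail_scores_ share the same parameter shapes. A caller can therefore validate its inputs once against that contract. The sketch below checks the documented shapes on nested Python lists. The function name `check_score_shapes` is hypothetical, and real callers pass PyTorch tensors, where the equivalent check would compare `tensor.shape`. It follows the listed shape for true_scores, a flat sequence of length batch_size.

```python
from typing import Optional, Sequence


def check_score_shapes(
    hrt_batch: Sequence[Sequence[int]],
    true_scores: Sequence[float],
    scores: Sequence[Sequence[float]],
    num_entities: int,
    dense_positive_mask: Optional[Sequence[Sequence[int]]] = None,
) -> None:
    """Raise ValueError if the arguments violate the documented shapes."""
    batch_size = len(hrt_batch)
    if any(len(triple) != 3 for triple in hrt_batch):
        raise ValueError("hrt_batch must have shape (batch_size, 3)")
    if len(true_scores) != batch_size:
        raise ValueError("true_scores must have shape (batch_size)")
    if len(scores) != batch_size or any(len(row) != num_entities for row in scores):
        raise ValueError("scores must have shape (batch_size, num_entities)")
    if dense_positive_mask is not None:
        if len(dense_positive_mask) != batch_size or any(
            len(row) != num_entities for row in dense_positive_mask
        ):
            raise ValueError("dense_positive_mask must have shape (batch_size, num_entities)")
        if any(v not in (0, 1) for row in dense_positive_mask for v in row):
            raise ValueError("dense_positive_mask must be binary (0/1)")


# A valid call with batch_size=1 and num_entities=2 returns None.
check_score_shapes([(0, 0, 1)], [0.5], [[0.1, 0.5]], 2, [[0, 1]])
```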