# -*- coding: utf-8 -*-

"""Basic structure for a negative sampler."""

from abc import ABC, abstractmethod
from typing import Any, ClassVar, Mapping, Optional

import torch

from ..triples import TriplesFactory
from ..utils import normalize_string

class NegativeSampler(ABC):
    """A negative sampler.

    Concrete subclasses implement :meth:`sample`, which produces negative
    triples from a batch of positive triples. This base class only stores the
    triples factory to sample from and the number of negatives to generate per
    positive triple.
    """

    #: The default strategy for optimizing the negative sampler's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Mapping[str, Any]]]

    def __init__(
        self,
        triples_factory: TriplesFactory,
        num_negs_per_pos: Optional[int] = None,
    ) -> None:
        """Initialize the negative sampler with the given entities.

        :param triples_factory: The factory holding the triples to sample from
        :param num_negs_per_pos: Number of negative samples to make per positive triple. Defaults to 1.
        :raises ValueError: If ``num_negs_per_pos`` is given and is smaller than 1.
        """
        if num_negs_per_pos is None:
            num_negs_per_pos = 1
        elif num_negs_per_pos < 1:
            # Reject non-positive counts up front instead of storing a value
            # that cannot describe "how many negatives per positive".
            raise ValueError(f"num_negs_per_pos must be at least 1, but got {num_negs_per_pos}")
        self.triples_factory = triples_factory
        self.num_negs_per_pos = num_negs_per_pos

    @classmethod
    def get_normalized_name(cls) -> str:
        """Get the normalized name of the negative sampler."""
        # Strip the common "NegativeSampler" suffix from the subclass name
        # before normalizing it.
        return normalize_string(cls.__name__, suffix=NegativeSampler.__name__)

    @property
    def num_entities(self) -> int:  # noqa: D401
        """The number of entities to sample from."""
        return self.triples_factory.num_entities

    @abstractmethod
    def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
        """Generate negative samples from the positive batch.

        :param positive_batch: The batch of positive triples. Presumably a
            ``(batch_size, 3)`` tensor of (head, relation, tail) ids; this is
            an assumption to be confirmed against subclasses and callers.
        :return: The generated negative triples. The exact shape is defined by
            the concrete subclass; presumably related to ``num_negs_per_pos``
            (TODO confirm).
        """
        raise NotImplementedError