Source code for pykeen.triples.instances

"""Implementation of basic instance factory which creates just instances based on standard KG triples."""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from typing import Generic, TypedDict, TypeVar

import numpy as np
import scipy.sparse
import torch
from class_resolver import HintOrType, OptionalKwargs, ResolverKey, update_docstring_with_resolver_keys
from torch.utils import data
from typing_extensions import NotRequired, Self

from .triples_factory import CoreTriplesFactory
from .utils import compute_compressed_adjacency_list
from .weights import LossWeighter, loss_weighter_resolver
from .. import typing as pykeen_typing
from ..constants import get_target_column
from ..sampling import NegativeSampler, negative_sampler_resolver
from ..typing import (
    BoolTensor,
    FloatTensor,
    LongTensor,
    MappedTriples,
    TargetColumn,
    TargetHint,
)
from ..utils import split_workload

__all__ = [
    "Instances",
    "LCWAInstances",
    "BaseBatchedSLCWAInstances",
    "BatchedSLCWAInstances",
    "SubGraphSLCWAInstances",
    "LCWABatch",
    "SLCWABatch",
]

# Type variable for the kind of batch a training-instance dataset provides.
BatchType = TypeVar("BatchType")


class LCWABatch(TypedDict):
    """The dictionary layout of one batch used for LCWA training."""

    #: the pairs for which the remaining column is to be predicted
    pairs: LongTensor

    #: the training target belonging to the pairs
    target: FloatTensor

    #: optional sample weights
    weights: NotRequired[FloatTensor]
class SLCWABatch(TypedDict):
    """The dictionary layout of one batch used for sLCWA training."""

    # TODO: separately storing head/relation/tail corruptions would enable faster scoring (and thus training)

    #: positive triples, shape: (batch_size, 3)
    positives: LongTensor

    #: negative triples, shape: (batch_size, num_negatives_per_positive, 3)
    negatives: LongTensor

    #: optional filter masks for the negative triples, shape: (batch_size, num_negatives_per_positive)
    masks: NotRequired[BoolTensor]

    #: optional sample weights of the positive triples
    pos_weights: NotRequired[FloatTensor]

    #: optional sample weights of the negative triples
    neg_weights: NotRequired[FloatTensor]
class Instances(data.Dataset[BatchType], Generic[BatchType], ABC):
    """Abstract dataset of training instances, generic in the type of batch it provides."""

    @abstractmethod
    def __len__(self):
        """Return how many instances the dataset contains."""
        raise NotImplementedError
class BaseBatchedSLCWAInstances(Instances[SLCWABatch], data.IterableDataset[SLCWABatch]):
    """Base class for sLCWA training instances which are batched by the dataset itself.

    Subclasses decide *which* positive triples form a batch by implementing :meth:`iter_triple_ids`;
    this class turns each such list of triple IDs into a full :class:`SLCWABatch`.

    .. note::

        this class is intended to be used with automatic batching disabled, i.e., both parameters
        `batch_size` and `batch_sampler` of `torch.utils.data.DataLoader` are set to `None`.
    """

    @update_docstring_with_resolver_keys(
        ResolverKey("negative_sampler", "pykeen.sampling.negative_sampler_resolver"),
        ResolverKey("loss_weighter", "pykeen.triples.weights.loss_weighter_resolver"),
    )
    def __init__(
        self,
        mapped_triples: MappedTriples,
        batch_size: int = 1,
        drop_last: bool = True,
        num_entities: int | None = None,
        num_relations: int | None = None,
        negative_sampler: HintOrType[NegativeSampler] = None,
        negative_sampler_kwargs: OptionalKwargs = None,
        loss_weighter: HintOrType[LossWeighter] = None,
        loss_weighter_kwargs: OptionalKwargs = None,
    ):
        """Initialize the dataset.

        :param mapped_triples: shape: (num_triples, 3)
            the mapped triples
        :param batch_size:
            the batch size
        :param drop_last:
            whether to drop the last (incomplete) batch
        :param num_entities: >0
            the number of entities, passed to the negative sampler
        :param num_relations: >0
            the number of relations, passed to the negative sampler
        :param negative_sampler:
            the negative sampler, or a hint thereof
        :param negative_sampler_kwargs:
            additional keyword-based parameters used to instantiate the negative sampler
        :param loss_weighter:
            The method to determine sample weights.
        :param loss_weighter_kwargs:
            Parameters for the method to determine sample weights.
        """
        self.mapped_triples = mapped_triples
        self.batch_size = batch_size
        self.drop_last = drop_last
        # optional component: stays ``None`` if no loss weighter was requested
        self.loss_weighter = loss_weighter_resolver.make_safe(loss_weighter, loss_weighter_kwargs)
        self.negative_sampler = negative_sampler_resolver.make(
            negative_sampler,
            pos_kwargs=negative_sampler_kwargs,
            mapped_triples=self.mapped_triples,
            num_entities=num_entities,
            num_relations=num_relations,
        )

    def __getitem__(self, item: list[int]) -> SLCWABatch:
        """Build the batch for the positive triples with the given IDs."""
        positives = self.mapped_triples[item]
        negatives, filter_masks = self.negative_sampler.sample(positive_batch=positives)
        batch = SLCWABatch(positives=positives, negatives=negatives)
        # the optional entries are only added when they are available
        if filter_masks is not None:
            batch["masks"] = filter_masks
        weighter = self.loss_weighter
        if weighter is not None:
            batch["pos_weights"] = weighter.weight_triples(positives)
            batch["neg_weights"] = weighter.weight_triples(negatives)
        return batch

    @abstractmethod
    def iter_triple_ids(self) -> Iterable[list[int]]:
        """Yield, batch by batch, the lists of IDs of the positive triples."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[SLCWABatch]:
        """Yield one batch per list of triple IDs from :meth:`iter_triple_ids`."""
        yield from (self[triple_ids] for triple_ids in self.iter_triple_ids())

    def __len__(self) -> int:
        """Return the number of batches."""
        num_triples = len(self.mapped_triples)
        num_full_batches = num_triples // self.batch_size
        has_incomplete_batch = num_triples % self.batch_size != 0
        if has_incomplete_batch and not self.drop_last:
            return num_full_batches + 1
        return num_full_batches

    @classmethod
    def from_triples_factory(cls, tf: CoreTriplesFactory, **kwargs) -> Self:
        """Create sLCWA instances for a triples factory.

        :param tf:
            the triples factory providing the triples and the number of entities and relations
        :param kwargs:
            additional keyword-based parameters passed to the constructor. The legacy keys ``shuffle``
            and ``sampler`` are removed before that.

        :returns:
            the instances

        :raises AssertionError:
            if ``shuffle`` is given with a false value, or ``sampler`` is given with a true value
        """
        # TODO: can we better type `kwargs`?
        if "shuffle" in kwargs:
            if not kwargs.pop("shuffle"):
                raise AssertionError("If shuffle is provided, it must be True.")
            warnings.warn("Training instances are always shuffled.", DeprecationWarning, stacklevel=2)
        if kwargs.pop("sampler", None):
            raise AssertionError("sampler is not handled in sLCWA instances")
        triples = tf._add_inverse_triples_if_necessary(mapped_triples=tf.mapped_triples)
        return cls(
            mapped_triples=triples,
            num_entities=tf.num_entities,
            num_relations=tf.num_relations,
            **kwargs,
        )
class BatchedSLCWAInstances(BaseBatchedSLCWAInstances):
    """Random pre-batched training instances for the sLCWA training loop."""

    # docstr-coverage: inherited
    def iter_triple_ids(self) -> Iterable[list[int]]:  # noqa: D102
        # The triple IDs this process is responsible for.
        # NOTE(review): assumes `split_workload` returns an indexable, sized sequence (e.g., a `range`),
        # which may be a proper sub-range when running inside a data loader worker - confirm.
        workload = split_workload(len(self.mapped_triples))
        batches_of_positions = data.BatchSampler(
            sampler=data.RandomSampler(data_source=workload),
            batch_size=self.batch_size,
            drop_last=self.drop_last,
        )
        for positions in batches_of_positions:
            # `RandomSampler` yields positions in `0..len(workload) - 1`, not the elements of `workload`;
            # translate them back to triple IDs. For `range(n)` this is the identity.
            yield [workload[position] for position in positions]
class SubGraphSLCWAInstances(BaseBatchedSLCWAInstances):
    """Pre-batched training instances for SLCWA of coherent subgraphs."""

    def __init__(self, **kwargs):
        """Initialize the instances.

        :param kwargs:
            keyword-based parameters passed to :meth:`BaseBatchedSLCWAInstances.__init__`
        """
        super().__init__(**kwargs)
        # indexing; as used by :meth:`subgraph_sample`:
        #   degrees[v]   - number of adjacency entries of vertex v
        #   offset[v]    - first row of v's adjacency entries in `neighbors`
        #   neighbors[i] - pair (edge number, other vertex)
        self.degrees, self.offset, self.neighbors = compute_compressed_adjacency_list(
            mapped_triples=self.mapped_triples
        )

    def subgraph_sample(self) -> list[int]:
        """Sample one subgraph.

        Starting from a random vertex, edges are added one at a time. Each new edge is incident to an
        already visited vertex, chosen with probability proportional to its number of not yet sampled
        edges. If no visited vertex has such an edge left, sampling continues from a random unvisited
        vertex which still has edges.

        :returns:
            the IDs of the sampled triples. These are ``batch_size`` many, unless there are no more
            edges left to sample, in which case the list is shorter.
        """
        # initialize
        # number of not yet sampled edges per vertex
        node_weights = self.degrees.detach().clone()
        edge_picked = torch.zeros(self.mapped_triples.shape[0], dtype=torch.bool)
        node_picked = torch.zeros(self.degrees.shape[0], dtype=torch.bool)

        # sample iteratively
        result: list[int] = []
        for _ in range(self.batch_size):
            # determine weights: only already visited vertices are candidates
            weights = node_weights * node_picked

            if torch.sum(weights) == 0:
                # randomly choose a vertex which has not been chosen yet. Vertices without remaining
                # edges are excluded, since no edge could be sampled for them.
                pool = ((~node_picked) & (node_weights > 0)).nonzero().flatten()
                if pool.numel() == 0:
                    # neither visited nor unvisited vertices have edges left: nothing more to sample
                    break
                chosen_vertex = pool[torch.randint(pool.numel(), size=tuple())]
            else:
                # normalize to probabilities
                probabilities = weights.float() / weights.sum().float()
                chosen_vertex = torch.multinomial(probabilities, num_samples=1)[0]

            # sample a start node
            node_picked[chosen_vertex] = True

            # get list of neighbors
            start = self.offset[chosen_vertex]
            chosen_node_degree = self.degrees[chosen_vertex].item()
            stop = start + chosen_node_degree
            adj_list = self.neighbors[start:stop, :]

            # sample an outgoing edge at random which has not been chosen yet using rejection sampling
            while True:
                chosen_edge_index = torch.randint(chosen_node_degree, size=(1,))[0]
                chosen_edge = adj_list[chosen_edge_index]
                edge_number = chosen_edge[0]
                if not edge_picked[edge_number]:
                    break
            result.append(edge_number.item())
            edge_picked[edge_number] = True

            # visit target node
            other_vertex = chosen_edge[1]
            node_picked[other_vertex] = True

            # decrease sample counts
            node_weights[chosen_vertex] -= 1
            node_weights[other_vertex] -= 1

        return result

    # docstr-coverage: inherited
    def iter_triple_ids(self) -> Iterable[list[int]]:  # noqa: D102
        # one sampled subgraph per batch; the workload only determines how many batches are produced
        for _ in split_workload(len(self)):
            yield self.subgraph_sample()
[docs] class LCWAInstances(Instances[LCWABatch]): """Triples and mappings to their indices for LCWA.""" @update_docstring_with_resolver_keys(ResolverKey("loss_weighter", "pykeen.triples.weights.loss_weighter_resolver")) def __init__( self, *, pairs: np.ndarray, compressed: scipy.sparse.csr_matrix, target: TargetHint = None, loss_weighter: HintOrType[LossWeighter] = None, loss_weighter_kwargs: OptionalKwargs = None, ): """Initialize the LCWA instances. :param pairs: The unique pairs :param compressed: The compressed triples in CSR format :param target: The prediction target. :param loss_weighter: The method to determine sample weights. :param loss_weighter_kwargs: Parameters for the method to determine sample weights. """ self.pairs = pairs self.compressed = compressed self.loss_weighter = loss_weighter_resolver.make_safe(loss_weighter, loss_weighter_kwargs) self.target: TargetColumn = get_target_column(target=target)
[docs] @classmethod def from_triples( cls, mapped_triples: MappedTriples, *, num_entities: int, num_relations: int, target: TargetHint = None, **kwargs, ) -> Self: """Create LCWA instances from triples. :param mapped_triples: shape: (num_triples, 3) The ID-based triples. :param num_entities: The number of entities. :param num_relations: The number of relations. :param target: The column to predict :param kwargs: Additional keyword-based parameters passed to :meth:`__init__` :returns: The instances. """ target = get_target_column(target) mapped_triples = mapped_triples.numpy() other_columns = sorted(set(range(3)).difference({target})) unique_pairs, pair_idx_to_triple_idx = np.unique(mapped_triples[:, other_columns], return_inverse=True, axis=0) num_pairs = unique_pairs.shape[0] tails = mapped_triples[:, target] target_size = num_relations if target == 1 else num_entities compressed = scipy.sparse.coo_matrix( (np.ones(mapped_triples.shape[0], dtype=np.float32), (pair_idx_to_triple_idx, tails)), shape=(num_pairs, target_size), ) # convert to csr for fast row slicing compressed = compressed.tocsr() return cls(pairs=unique_pairs, compressed=compressed, target=target, **kwargs)
[docs] @classmethod def from_triples_factory(cls, tf: CoreTriplesFactory, **kwargs) -> Self: """Create LCWA instances for triples factory. :param tf: The triples factory. :param kwargs: Additional keyword-based parameters passed to :meth:`from_triples` :returns: The instances. """ return cls.from_triples( mapped_triples=tf._add_inverse_triples_if_necessary(mapped_triples=tf.mapped_triples), num_entities=tf.num_entities, num_relations=tf.num_relations, **kwargs, )
def __len__(self) -> int: # noqa: D105 return self.pairs.shape[0] def __getitem__(self, item: int) -> LCWABatch: # noqa: D105 pairs = self.pairs[item] result = LCWABatch(pairs=pairs, target=torch.from_numpy(np.asarray(self.compressed[item, :].todense())[0, :])) if self.loss_weighter is None: return result x = pairs[..., None, 0] y = pairs[..., None, 1] match self.target: # note: we need qualification here case pykeen_typing.COLUMN_HEAD: result["weights"] = self.loss_weighter(h=None, r=x, t=y) case pykeen_typing.COLUMN_RELATION: result["weights"] = self.loss_weighter(h=x, r=None, t=y) case pykeen_typing.COLUMN_TAIL: result["weights"] = self.loss_weighter(h=x, r=y, t=None) return result