# Source code for pykeen.triples.instances

"""Implementation of basic instance factory which creates just instances based on standard KG triples."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator
from typing import Generic, NamedTuple, TypeVar

import numpy as np
import scipy.sparse
import torch
from class_resolver import HintOrType, OptionalKwargs
from torch.utils import data

from .utils import compute_compressed_adjacency_list
from ..sampling import NegativeSampler, negative_sampler_resolver
from ..typing import BoolTensor, FloatTensor, LongTensor, MappedTriples
from ..utils import split_workload

__all__ = [
    "Instances",
    "SLCWAInstances",
    "LCWAInstances",
]

# TODO: the same
SampleType = TypeVar("SampleType")
BatchType = TypeVar("BatchType")
LCWASampleType = tuple[MappedTriples, FloatTensor]
LCWABatchType = tuple[MappedTriples, FloatTensor]
SLCWASampleType = tuple[MappedTriples, MappedTriples, BoolTensor | None]


class SLCWABatch(NamedTuple):
    """A batch for sLCWA training.

    Bundles the positive triples with their sampled negative triples and an
    optional filtering mask over those negatives.
    """

    #: the positive triples, shape: (batch_size, 3)
    positives: LongTensor

    #: the negative triples, shape: (batch_size, num_negatives_per_positive, 3)
    negatives: LongTensor

    #: filtering masks for negative triples, shape: (batch_size, num_negatives_per_positive)
    #: may be None when the negative sampler does not provide filtering information
    masks: BoolTensor | None


class Instances(data.Dataset[BatchType], Generic[SampleType, BatchType], ABC):
    """Base class for training instances."""

    def __len__(self):  # noqa:D401
        """Get the number of instances."""
        raise NotImplementedError

    def get_collator(self) -> Callable[[list[SampleType]], BatchType] | None:
        """Get a collator."""
        # by default, no custom collation function is provided
        return None

    @classmethod
    def from_triples(
        cls,
        mapped_triples: MappedTriples,
        *,
        num_entities: int,
        num_relations: int,
        **kwargs,
    ) -> Instances:
        """Create instances from mapped triples.

        :param mapped_triples: shape: (num_triples, 3)
            The ID-based triples.
        :param num_entities: >0
            The number of entities.
        :param num_relations: >0
            The number of relations.
        :param kwargs:
            additional keyword-based parameters.

        :returns:
            The instances.

        # noqa:DAR201
        # noqa:DAR202
        # noqa:DAR401
        """
        raise NotImplementedError
class SLCWAInstances(Instances[SLCWASampleType, SLCWABatch]):
    """Training instances for the sLCWA."""

    def __init__(
        self,
        *,
        mapped_triples: MappedTriples,
        num_entities: int | None = None,
        num_relations: int | None = None,
        negative_sampler: HintOrType[NegativeSampler] = None,
        negative_sampler_kwargs: OptionalKwargs = None,
    ):
        """Initialize the sLCWA instances.

        :param mapped_triples: shape: (num_triples, 3)
            the ID-based triples, passed to the negative sampler
        :param num_entities: >0
            the number of entities, passed to the negative sampler
        :param num_relations: >0
            the number of relations, passed to the negative sampler
        :param negative_sampler:
            the negative sampler, or a hint thereof
        :param negative_sampler_kwargs:
            additional keyword-based arguments passed to the negative sampler
        """
        self.mapped_triples = mapped_triples
        # resolve the negative sampler hint into a concrete instance
        self.sampler = negative_sampler_resolver.make(
            negative_sampler,
            pos_kwargs=negative_sampler_kwargs,
            mapped_triples=mapped_triples,
            num_entities=num_entities,
            num_relations=num_relations,
        )

    def __len__(self) -> int:  # noqa: D105
        # one instance per positive triple
        return self.mapped_triples.shape[0]

    def __getitem__(self, item: int) -> SLCWASampleType:  # noqa: D105
        # keep a batch dimension of size 1 for the sampler
        positive = self.mapped_triples[item].unsqueeze(dim=0)
        # TODO: some negative samplers require batches
        negative, mask = self.sampler.sample(positive_batch=positive)
        # shape: (1, 3), (1, k, 3), (1, k, 3)?
        return positive, negative, mask

    @staticmethod
    def collate(samples: Iterable[SLCWASampleType]) -> SLCWABatch:
        """Collate samples."""
        # each sample: (1, 3), (1, k, 3), and an optional mask
        positive_list, negative_list, mask_list = zip(*samples, strict=False)
        mask_batch: BoolTensor | None = None
        if mask_list[0] is not None:
            mask_batch = torch.cat(mask_list, dim=0)
        else:
            # either all samples carry masks, or none of them do
            assert all(m is None for m in mask_list)
        return SLCWABatch(
            torch.cat(positive_list, dim=0),
            torch.cat(negative_list, dim=0),
            mask_batch,
        )

    # docstr-coverage: inherited
    def get_collator(self) -> Callable[[list[SLCWASampleType]], SLCWABatch] | None:  # noqa: D102
        return self.collate

    # docstr-coverage: inherited
    @classmethod
    def from_triples(
        cls,
        mapped_triples: MappedTriples,
        *,
        num_entities: int,
        num_relations: int,
        **kwargs,
    ) -> Instances:  # noqa: D102
        return cls(mapped_triples=mapped_triples, num_entities=num_entities, num_relations=num_relations, **kwargs)
class BaseBatchedSLCWAInstances(data.IterableDataset[SLCWABatch]):
    """Pre-batched training instances for the sLCWA training loop.

    .. note::
        this class is intended to be used with automatic batching disabled, i.e., both parameters `batch_size` and
        `batch_sampler` of torch.utils.data.DataLoader` are set to `None`.
    """

    def __init__(
        self,
        mapped_triples: MappedTriples,
        batch_size: int = 1,
        drop_last: bool = True,
        num_entities: int | None = None,
        num_relations: int | None = None,
        negative_sampler: HintOrType[NegativeSampler] = None,
        negative_sampler_kwargs: OptionalKwargs = None,
    ):
        """Initialize the dataset.

        :param mapped_triples: shape: (num_triples, 3)
            the mapped triples
        :param batch_size:
            the batch size
        :param drop_last:
            whether to drop the last (incomplete) batch
        :param num_entities: >0
            the number of entities, passed to the negative sampler
        :param num_relations: >0
            the number of relations, passed to the negative sampler
        :param negative_sampler:
            the negative sampler, or a hint thereof
        :param negative_sampler_kwargs:
            additional keyword-based parameters used to instantiate the negative sampler
        """
        self.mapped_triples = mapped_triples
        self.batch_size = batch_size
        self.drop_last = drop_last
        # resolve the negative sampler hint into a concrete instance
        self.negative_sampler = negative_sampler_resolver.make(
            negative_sampler,
            pos_kwargs=negative_sampler_kwargs,
            mapped_triples=self.mapped_triples,
            num_entities=num_entities,
            num_relations=num_relations,
        )

    def __getitem__(self, item: list[int]) -> SLCWABatch:
        """Get a batch from the given list of positive triple IDs."""
        positive_batch = self.mapped_triples[item]
        negative_batch, masks = self.negative_sampler.sample(positive_batch=positive_batch)
        return SLCWABatch(positives=positive_batch, negatives=negative_batch, masks=masks)

    @abstractmethod
    def iter_triple_ids(self) -> Iterable[list[int]]:
        """Iterate over batches of IDs of positive triples."""
        raise NotImplementedError

    def __iter__(self) -> Iterator[SLCWABatch]:
        """Iterate over batches."""
        for ids in self.iter_triple_ids():
            yield self[ids]

    def __len__(self) -> int:
        """Return the number of batches."""
        full_batches, remainder = divmod(len(self.mapped_triples), self.batch_size)
        # a trailing incomplete batch only counts when it is not dropped
        return full_batches + 1 if remainder and not self.drop_last else full_batches


class BatchedSLCWAInstances(BaseBatchedSLCWAInstances):
    """Random pre-batched training instances for the sLCWA training loop."""

    # docstr-coverage: inherited
    def iter_triple_ids(self) -> Iterable[list[int]]:  # noqa: D102
        # shuffle this worker's share of the triple IDs, then chunk into batches
        shuffled = data.RandomSampler(data_source=split_workload(len(self.mapped_triples)))
        yield from data.BatchSampler(sampler=shuffled, batch_size=self.batch_size, drop_last=self.drop_last)


class SubGraphSLCWAInstances(BaseBatchedSLCWAInstances):
    """Pre-batched training instances for SLCWA of coherent subgraphs."""

    def __init__(self, **kwargs):
        """Initialize the instances.

        :param kwargs:
            keyword-based parameters passed to :meth:`BaseBatchedSLCWAInstances.__init__`
        """
        super().__init__(**kwargs)
        # pre-compute a compressed adjacency list for fast neighborhood lookup
        self.degrees, self.offset, self.neighbors = compute_compressed_adjacency_list(
            mapped_triples=self.mapped_triples
        )

    def subgraph_sample(self) -> list[int]:
        """Sample one subgraph."""
        # per-node budget of not-yet-used incident edges
        remaining = self.degrees.detach().clone()
        edge_used = torch.zeros(self.mapped_triples.shape[0], dtype=torch.bool)
        node_visited = torch.zeros(self.degrees.shape[0], dtype=torch.bool)
        # grow the subgraph one edge at a time
        triple_ids = []
        for _ in range(self.batch_size):
            # restrict sampling weights to already visited nodes
            weights = remaining * node_visited
            if torch.sum(weights) == 0:
                # no expandable frontier left: restart from a random unvisited vertex
                candidates = (~node_visited).nonzero()
                vertex = candidates[torch.randint(candidates.numel(), size=tuple())]
            else:
                # choose a visited vertex proportionally to its remaining budget
                vertex = torch.multinomial(weights.float() / weights.sum().float(), num_samples=1)[0]
            node_visited[vertex] = True
            # slice this vertex's neighborhood out of the compressed adjacency list
            begin = self.offset[vertex]
            degree = self.degrees[vertex].item()
            adjacency = self.neighbors[begin : begin + degree, :]
            # rejection-sample an incident edge that has not been used yet
            while True:
                edge = adjacency[torch.randint(degree, size=(1,))[0]]
                edge_id = edge[0]
                if not edge_used[edge_id]:
                    break
            triple_ids.append(edge_id.item())
            edge_used[edge_id] = True
            # visit the edge's other endpoint
            other = edge[1]
            node_visited[other] = True
            # consume one unit of budget at both endpoints
            remaining[vertex] -= 1
            remaining[other] -= 1
        return triple_ids

    # docstr-coverage: inherited
    def iter_triple_ids(self) -> Iterable[list[int]]:  # noqa: D102
        for _ in split_workload(len(self)):
            yield self.subgraph_sample()
class LCWAInstances(Instances[LCWASampleType, LCWABatchType]):
    """Triples and mappings to their indices for LCWA."""

    def __init__(self, *, pairs: np.ndarray, compressed: scipy.sparse.csr_matrix):
        """Initialize the LCWA instances.

        :param pairs:
            The unique pairs
        :param compressed:
            The compressed triples in CSR format
        """
        self.pairs = pairs
        self.compressed = compressed

    @classmethod
    def from_triples(
        cls,
        mapped_triples: MappedTriples,
        *,
        num_entities: int,
        num_relations: int,
        target: int | None = None,
        **kwargs,
    ) -> Instances:
        """Create LCWA instances from triples.

        :param mapped_triples: shape: (num_triples, 3)
            The ID-based triples.
        :param num_entities:
            The number of entities.
        :param num_relations:
            The number of relations.
        :param target:
            The column to predict
        :param kwargs:
            Keyword arguments (thrown out)

        :returns:
            The instances.
        """
        # default to predicting the tail column
        target = 2 if target is None else target
        triples = mapped_triples.numpy()
        # the two columns forming the "input" side of each instance
        input_columns = [column for column in range(3) if column != target]
        # map every triple to the index of its (deduplicated) input pair
        unique_pairs, triple_to_pair = np.unique(triples[:, input_columns], return_inverse=True, axis=0)
        # size of the label space: relations when predicting column 1, entities otherwise
        num_targets = num_relations if target == 1 else num_entities
        # sparse multi-hot label matrix: one row per unique pair, one column per target ID;
        # convert to CSR for fast row slicing in __getitem__
        labels = scipy.sparse.coo_matrix(
            (np.ones(triples.shape[0], dtype=np.float32), (triple_to_pair, triples[:, target])),
            shape=(unique_pairs.shape[0], num_targets),
        ).tocsr()
        return cls(pairs=unique_pairs, compressed=labels)

    def __len__(self) -> int:  # noqa: D105
        return self.pairs.shape[0]

    def __getitem__(self, item: int) -> LCWABatchType:  # noqa: D105
        # densify a single row of the sparse label matrix
        labels = np.asarray(self.compressed[item, :].todense())[0, :]
        return self.pairs[item], labels