"""Implementation of triples splitting functions."""
import logging
import typing
from abc import abstractmethod
from collections.abc import Collection, Sequence
from typing import Optional, Union
import numpy
import pandas
import torch
from class_resolver.api import ClassResolver, HintOrType
from ..constants import COLUMN_LABELS
from ..typing import LABEL_HEAD, LABEL_RELATION, LABEL_TAIL, BoolTensor, MappedTriples, Target, TorchRandomHint
from ..utils import ensure_torch_random_state
logger = logging.getLogger(__name__)
__all__ = [
"split",
# Cleaners
"cleaner_resolver",
"Cleaner",
"RandomizedCleaner",
"DeterministicCleaner",
# Splitters
"splitter_resolver",
"Splitter",
"CleanupSplitter",
"CoverageSplitter",
# Utils
"TripleCoverageError",
"normalize_ratios",
"get_absolute_split_sizes",
]
def _split_triples(
mapped_triples: MappedTriples,
sizes: Sequence[int],
random_state: TorchRandomHint = None,
) -> Sequence[MappedTriples]:
"""
Randomly split triples into groups of given sizes.
:param mapped_triples: shape: (n, 3)
The triples.
    :param sizes:
        The sizes of the groups; they must sum up to the number of triples.
    :param random_state:
        The random state for reproducible splits.
    :return:
        The split triples.
    :raises ValueError:
        If the given sizes do not sum up to the number of triples.
"""
num_triples = mapped_triples.shape[0]
if sum(sizes) != num_triples:
raise ValueError(f"Received {num_triples} triples, but the sizes sum up to {sum(sizes)}")
    # Split indices; normalize the random state first, since e.g. an integer seed cannot be passed as a generator
    random_state = ensure_torch_random_state(random_state)
    idx = torch.randperm(num_triples, generator=random_state)
idx_groups = idx.split(split_size=sizes, dim=0)
# Split triples
triples_groups = [mapped_triples[idx] for idx in idx_groups]
logger.info(
"done splitting triples to groups of sizes %s",
[triples.shape[0] for triples in triples_groups],
)
return triples_groups
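# Illustrative usage sketch (not part of the original module; the helper name and
# the random toy triples are hypothetical): splitting 100 ID-based triples into
# groups of sizes 80/10/10 with a seeded generator.
def _example_split_triples() -> None:
    generator = torch.manual_seed(42)
    mapped_triples = torch.randint(10, size=(100, 3))
    train, test, validation = _split_triples(mapped_triples, sizes=[80, 10, 10], random_state=generator)
    # the groups partition the input exactly, since no cleanup is involved here
    assert train.shape == (80, 3) and test.shape == (10, 3) and validation.shape == (10, 3)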
def _get_cover_for_column(df: pandas.DataFrame, column: Target, index_column: str = "index") -> set[int]:
    """Select, for each unique value in the given column, the triple with the smallest index."""
    return set(df.groupby(by=column).agg({index_column: "min"})[index_column].values)
def _get_covered_entities(df: pandas.DataFrame, chosen: Collection[int]) -> set[int]:
    """Collect all entities (heads and tails) occurring in the chosen triples."""
    return set(numpy.unique(df.loc[df["index"].isin(chosen), [LABEL_HEAD, LABEL_TAIL]]))
def _get_cover_deterministic(triples: MappedTriples) -> BoolTensor:
"""
Get a coverage mask for all entities and relations.
    The implementation uses a greedy coverage algorithm for selecting triples. If there are multiple triples to
    choose from, the one with the smaller ID is preferred.
1. Select one triple for each relation.
2. Select one triple for each head entity, which is not yet covered.
3. Select one triple for each tail entity, which is not yet covered.
The cover is guaranteed to contain at most $num_relations + num_unique_heads + num_unique_tails$ triples.
:param triples: shape: (n, 3)
The triples (ID-based).
:return: shape: (n,)
A boolean mask indicating whether the triple is part of the cover.
"""
df = pandas.DataFrame(data=triples.numpy(), columns=COLUMN_LABELS).reset_index()
# select one triple per relation
chosen = _get_cover_for_column(df=df, column=LABEL_RELATION)
# Select one triple for each head/tail entity, which is not yet covered.
for column in (LABEL_HEAD, LABEL_TAIL):
covered = _get_covered_entities(df=df, chosen=chosen)
chosen |= _get_cover_for_column(df=df[~df[column].isin(covered)], column=column)
# create mask
num_triples = triples.shape[0]
seed_mask = torch.zeros(num_triples, dtype=torch.bool)
seed_mask[list(chosen)] = True
return seed_mask
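# Illustrative sketch (not part of the original module; the helper name and the
# hand-written triples are hypothetical): the returned mask marks a subset of
# triples that together touch every head, relation, and tail at least once.
def _example_cover_deterministic() -> None:
    triples = torch.as_tensor([[0, 0, 1], [1, 0, 2], [2, 1, 0]])
    mask = _get_cover_deterministic(triples=triples)
    covered = triples[mask]
    # all entity IDs and all relation IDs occur within the covered triples
    assert set(covered[:, [0, 2]].flatten().tolist()) == {0, 1, 2}
    assert set(covered[:, 1].tolist()) == {0, 1}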
class TripleCoverageError(RuntimeError):
"""An exception thrown when not all entities/relations are covered by triples."""
def __init__(self, arr, name: str = "ids"):
"""
Initialize the error.
:param arr: shape: (num_indices,)
the array of covering triple IDs
:param name:
the name to use for creating the error message
"""
r = sorted((arr < 0).nonzero(as_tuple=False))
super().__init__(
f"Could not cover the following {name} from the provided triples: {r}. One possible reason is that you are"
f" working with triples from a non-compact ID mapping, i.e. where the IDs do not span the full range of "
f"[0, ..., num_ids - 1]",
)
def normalize_ratios(
ratios: Union[float, Sequence[float]],
epsilon: float = 1.0e-06,
) -> tuple[float, ...]:
"""Normalize relative sizes.
    If the sum of the given ratios is smaller than 1, the remainder (1 - sum) is appended as an additional ratio.
:param ratios:
The ratios.
:param epsilon:
A small constant for comparing sum of ratios against 1.
:return:
A sequence of ratios of at least two elements which sums to one.
:raises ValueError:
if the ratio sum is bigger than 1.0
"""
# Prepare split index
if isinstance(ratios, float):
ratios = [ratios]
ratios = tuple(ratios)
ratio_sum = sum(ratios)
if ratio_sum < 1.0 - epsilon:
ratios = ratios + (1.0 - ratio_sum,)
elif ratio_sum > 1.0 + epsilon:
raise ValueError(f"ratios sum to more than 1.0: {ratios} (sum={ratio_sum})")
return ratios
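# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): the two incomplete forms of ``ratios`` are filled up to sum to one.
def _example_normalize_ratios() -> None:
    # a single float: the remainder is appended, giving roughly (0.8, 0.2)
    two = normalize_ratios(0.8)
    # an incomplete sequence: the final ratio is filled in, giving roughly (0.8, 0.1, 0.1)
    three = normalize_ratios([0.8, 0.1])
    assert len(two) == 2 and len(three) == 3
    assert abs(sum(two) - 1.0) < 1.0e-06 and abs(sum(three) - 1.0) < 1.0e-06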
def get_absolute_split_sizes(
n_total: int,
ratios: Sequence[float],
) -> tuple[int, ...]:
"""
Compute absolute sizes of splits from given relative sizes.
.. note ::
This method compensates for rounding errors, and ensures that the absolute sizes sum up to the total number.
:param n_total:
The total number.
:param ratios:
The relative ratios (should sum to 1).
:return:
The absolute sizes.
"""
# due to rounding errors we might lose a few points, thus we use cumulative ratio
cum_ratio = numpy.cumsum(ratios)
cum_ratio[-1] = 1.0
cum_ratio = numpy.r_[numpy.zeros(1), cum_ratio]
split_points = (cum_ratio * n_total).astype(numpy.int64)
sizes = numpy.diff(split_points)
return tuple(sizes)
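# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): rounding is compensated so that the absolute sizes always sum
# to ``n_total``, even when the ratios do not divide it evenly.
def _example_get_absolute_split_sizes() -> None:
    sizes = get_absolute_split_sizes(n_total=7, ratios=[1 / 3, 1 / 3, 1 / 3])
    # naive rounding would yield 2 + 2 + 2 = 6; here the last part absorbs the remainder
    assert sizes == (2, 2, 3)
    assert sum(sizes) == 7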
class Cleaner:
"""A cleanup method for ensuring that all entities are contained in the triples of the first split part."""
@abstractmethod
def cleanup_pair(
self,
reference: MappedTriples,
other: MappedTriples,
random_state: TorchRandomHint,
) -> tuple[MappedTriples, MappedTriples]:
"""
Clean up one set of triples with respect to a reference set.
:param reference:
the reference set of triples, which shall contain triples for all entities
:param other:
the other set of triples
:param random_state:
the random state to use, if any randomized operations take place
:return:
a pair (reference, other), where some triples of other may have been moved into reference
"""
raise NotImplementedError
def __call__(
self,
triples_groups: Sequence[MappedTriples],
random_state: TorchRandomHint,
) -> Sequence[MappedTriples]:
"""Cleanup a list of triples array with respect to the first array."""
reference, *others = triples_groups
result = []
for other in others:
reference, other = self.cleanup_pair(reference=reference, other=other, random_state=random_state)
result.append(other)
return reference, *result
def _prepare_cleanup(
training: MappedTriples,
testing: MappedTriples,
max_ids: Optional[tuple[int, int]] = None,
) -> BoolTensor:
"""
    Calculate a mask over the test triples marking those which contain test-only entities or relations.
:param training: shape: (n, 3)
The training triples.
:param testing: shape: (m, 3)
The testing triples.
:param max_ids:
The maximum identifier in each column. Calculates it automatically if not given.
:return: shape: (m,)
The move mask.
"""
# base cases
if len(testing) == 0:
return torch.empty(0, dtype=torch.bool)
if len(training) == 0:
return torch.ones(testing.shape[0], dtype=torch.bool)
columns = [[0, 2], [1]]
to_move_mask = torch.zeros(1, dtype=torch.bool)
if max_ids is None:
max_ids = typing.cast(
tuple[int, int],
tuple(max(training[:, col].max().item(), testing[:, col].max().item()) + 1 for col in columns),
)
for col, max_id in zip(columns, max_ids):
# IDs not in training
not_in_training_mask = torch.ones(max_id, dtype=torch.bool)
not_in_training_mask[training[:, col].view(-1)] = False
# triples with exclusive test IDs
exclusive_triples = not_in_training_mask[testing[:, col].view(-1)].view(-1, len(col)).any(dim=-1)
to_move_mask = to_move_mask | exclusive_triples
return to_move_mask
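# Illustrative sketch (not part of the original module; the helper name and toy
# triples are hypothetical): the mask flags the test triple containing entity 2,
# which does not occur anywhere in the training triples.
def _example_prepare_cleanup() -> None:
    training = torch.as_tensor([[0, 0, 1]])
    testing = torch.as_tensor([[0, 0, 2], [1, 0, 0]])
    mask = _prepare_cleanup(training=training, testing=testing)
    assert mask.tolist() == [True, False]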
class RandomizedCleaner(Cleaner):
    """Clean up a triples array by randomly selecting testing triples to move, recalculating after each move to minimize the total number of moves.

    1. Calculate ``move_id_mask`` as in :func:`_prepare_cleanup`.
    2. Choose a random triple to move and recalculate ``move_id_mask``.
    3. Continue until ``move_id_mask`` has no true bits.
    """
# docstr-coverage: inherited
def cleanup_pair(
self,
reference: MappedTriples,
other: MappedTriples,
random_state: TorchRandomHint,
) -> tuple[MappedTriples, MappedTriples]: # noqa: D102
generator = ensure_torch_random_state(random_state)
move_id_mask = _prepare_cleanup(reference, other)
# While there are still triples that should be moved to the training set
while move_id_mask.any():
# Pick a random triple to move over to the training triples
(candidates,) = move_id_mask.nonzero(as_tuple=True)
# TODO: this could easily be extended to select a batch of triples
# -> speeds up the process at the cost of slightly larger movements
idx = torch.randint(candidates.shape[0], size=(1,), generator=generator)
idx = candidates[idx]
# add to training
reference = torch.cat([reference, other[idx].view(1, -1)], dim=0)
# remove from testing
other = torch.cat([other[:idx], other[idx + 1 :]], dim=0)
# Recalculate the move_id_mask
move_id_mask = _prepare_cleanup(reference, other)
return reference, other
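# Illustrative sketch (not part of the original module; the helper name and toy
# triples are hypothetical): after cleanup, no test triple contains entities or
# relations that are unseen in the reference (training) part.
def _example_randomized_cleaner() -> None:
    train = torch.as_tensor([[0, 0, 1]])
    test = torch.as_tensor([[2, 0, 3], [3, 0, 2]])
    train2, test2 = RandomizedCleaner().cleanup_pair(reference=train, other=test, random_state=42)
    # the move mask is now empty: everything in test2 is covered by train2
    assert not _prepare_cleanup(train2, test2).any()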
class DeterministicCleaner(Cleaner):
"""Cleanup a triples array (testing) with respect to another (training)."""
# docstr-coverage: inherited
def cleanup_pair(
self,
reference: MappedTriples,
other: MappedTriples,
random_state: TorchRandomHint,
) -> tuple[MappedTriples, MappedTriples]: # noqa: D102
move_id_mask = _prepare_cleanup(reference, other)
reference = torch.cat([reference, other[move_id_mask]])
other = other[~move_id_mask]
return reference, other
#: A resolver for triple cleaners
cleaner_resolver: ClassResolver[Cleaner] = ClassResolver.from_subclasses(base=Cleaner, default=DeterministicCleaner)
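# Illustrative sketch (not part of the original module; the helper name and toy
# triples are hypothetical): cleaners can be resolved by (normalized) name, and
# calling one moves offending test triples into the training part.
def _example_cleaner_resolver() -> None:
    cleaner = cleaner_resolver.make(None)  # the default is the deterministic cleaner
    assert isinstance(cleaner, DeterministicCleaner)
    assert isinstance(cleaner_resolver.make("randomized"), RandomizedCleaner)
    train = torch.as_tensor([[0, 0, 1], [1, 1, 0]])
    test = torch.as_tensor([[0, 1, 1], [1, 2, 0]])
    train2, test2 = cleaner([train, test], random_state=None)
    # the test triple with the unseen relation 2 was moved into training
    assert train2.shape == (3, 3) and test2.shape == (1, 3)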
class Splitter:
"""A method for splitting triples."""
@abstractmethod
def split_absolute_size(
self,
mapped_triples: MappedTriples,
sizes: Sequence[int],
random_state: torch.Generator,
) -> Sequence[MappedTriples]:
"""Split triples into clean groups.
This method partitions the triples, i.e., each triple is in exactly one group. Moreover, it ensures that
the first group contains all entities at least once.
:param mapped_triples: shape: (n, 3)
the ID-based triples
:param sizes:
the absolute number of triples for each split part.
:param random_state:
the random state used for splitting
        :return:
            a sequence of ID-based triples for each split part; the absolute sizes may differ from the requested
            ones to ensure the constraint.
"""
raise NotImplementedError
def split(
self,
*,
mapped_triples: MappedTriples,
ratios: Union[float, Sequence[float]] = 0.8,
random_state: TorchRandomHint = None,
) -> Sequence[MappedTriples]:
"""Split triples into clean groups.
:param mapped_triples: shape: (n, 3)
the ID-based triples
:param random_state:
the random state used to shuffle and split the triples
:param ratios:
There are three options for this argument.
First, a float can be given between 0 and 1.0, non-inclusive. The first set of triples will get this
ratio and the second will get the rest.
Second, a list of ratios can be given for which set in which order should get what ratios as
in ``[0.8, 0.1]``.
The final ratio can be omitted because that can be calculated.
Third, all ratios can be explicitly set in order such as in ``[0.8, 0.1, 0.1]`` where the sum of
all ratios is 1.0.
:return:
A partition of triples, which are split (approximately) according to the ratios.
"""
random_state = ensure_torch_random_state(random_state)
ratios = normalize_ratios(ratios=ratios)
sizes = get_absolute_split_sizes(n_total=mapped_triples.shape[0], ratios=ratios)
triples_groups = self.split_absolute_size(
mapped_triples=mapped_triples,
sizes=sizes,
random_state=random_state,
)
for i, (triples, exp_size, exp_ratio) in enumerate(zip(triples_groups, sizes, ratios)):
actual_size = triples.shape[0]
actual_ratio = actual_size / exp_size * exp_ratio
if actual_size != exp_size:
logger.warning(
f"Requested ratio[{i}]={exp_ratio:.3f} (equal to size {exp_size}), but got {actual_ratio:.3f} "
f"(equal to size {actual_size}) to ensure that all entities/relations occur in train.",
)
return triples_groups
class CleanupSplitter(Splitter):
"""
The cleanup splitter first randomly splits the triples and then cleans up.
In the cleanup process, triples are moved into the train part until all entities occur at least once in train.
The splitter supports two variants of cleanup, cf. ``cleaner_resolver``.
"""
def __init__(self, cleaner: HintOrType[Cleaner] = None) -> None:
"""
Initialize the splitter.
        :param cleaner:
            the cleanup method to use. Defaults to the fast deterministic cleaner,
            which may lead to larger deviations between the desired and the actual triple counts.
"""
self.cleaner = cleaner_resolver.make(cleaner)
# docstr-coverage: inherited
def split_absolute_size(
self,
mapped_triples: MappedTriples,
sizes: Sequence[int],
random_state: torch.Generator,
) -> Sequence[MappedTriples]: # noqa: D102
triples_groups = _split_triples(
mapped_triples,
sizes=sizes,
random_state=random_state,
)
# Make sure that the first element has all the right stuff in it
logger.debug("cleaning up groups")
triples_groups = self.cleaner(triples_groups, random_state=random_state)
logger.debug("done cleaning up groups")
return triples_groups
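# Illustrative sketch (not part of the original module; the helper name and the
# random toy triples are hypothetical): the cleanup splitter accepts a cleaner
# hint, e.g. the randomized variant, which is resolved via ``cleaner_resolver``.
def _example_cleanup_splitter() -> None:
    mapped_triples = torch.randint(10, size=(100, 3))
    train, test = CleanupSplitter(cleaner="randomized").split(
        mapped_triples=mapped_triples, ratios=0.8, random_state=0
    )
    # cleanup only ever moves triples into the first part, so train may exceed 80
    assert train.shape[0] >= 80
    assert train.shape[0] + test.shape[0] == 100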
class CoverageSplitter(Splitter):
"""This splitter greedily selects training triples such that each entity is covered and then splits the rest."""
# docstr-coverage: inherited
def split_absolute_size(
self,
mapped_triples: MappedTriples,
sizes: Sequence[int],
random_state: torch.Generator,
) -> Sequence[MappedTriples]: # noqa: D102
seed_mask = _get_cover_deterministic(triples=mapped_triples)
train_seed = mapped_triples[seed_mask]
remaining_triples = mapped_triples[~seed_mask]
if train_seed.shape[0] > sizes[0]:
raise ValueError(f"Could not find a coverage of all entities and relation with only {sizes[0]} triples.")
remaining_sizes = (sizes[0] - train_seed.shape[0],) + tuple(sizes[1:])
train, *rest = _split_triples(
mapped_triples=remaining_triples,
sizes=remaining_sizes,
random_state=random_state,
)
return [torch.cat([train_seed, train], dim=0), *rest]
#: A resolver for triple splitters
splitter_resolver: ClassResolver[Splitter] = ClassResolver.from_subclasses(base=Splitter, default=CoverageSplitter)
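# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): splitters are resolved by normalized name, e.g. "coverage" or
# "cleanup", with the coverage splitter being the default.
def _example_splitter_resolver() -> None:
    assert isinstance(splitter_resolver.make(None), CoverageSplitter)
    assert isinstance(splitter_resolver.make("cleanup"), CleanupSplitter)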
def split(
mapped_triples: MappedTriples,
ratios: Union[float, Sequence[float]] = 0.8,
random_state: TorchRandomHint = None,
randomize_cleanup: bool = False,
method: Optional[str] = None,
) -> Sequence[MappedTriples]:
"""Split triples into clean groups.
:param mapped_triples: shape: (n, 3)
The ID-based triples.
:param ratios:
There are three options for this argument.
First, a float can be given between 0 and 1.0, non-inclusive. The first set of triples will get this ratio and
the second will get the rest.
Second, a list of ratios can be given for which set in which order should get what ratios as in ``[0.8, 0.1]``.
The final ratio can be omitted because that can be calculated.
Third, all ratios can be explicitly set in order such as in ``[0.8, 0.1, 0.1]`` where the sum of all ratios is
1.0.
:param random_state:
The random state used to shuffle and split the triples.
:param randomize_cleanup:
If true, uses the non-deterministic method for moving triples to the training set. This has the advantage that
it does not necessarily have to move all of them, but it might be significantly slower since it moves one
triple at a time.
:param method:
The name of the method to use, cf. :data:`splitter_resolver`. Defaults to "coverage", i.e.,
:class:`CoverageSplitter`.
:return:
A partition of triples, which are split (approximately) according to the ratios.
.. code-block:: python
ratio = 0.8 # makes a [0.8, 0.2] split
train, test = split(triples, ratio)
ratios = [0.8, 0.1] # makes a [0.8, 0.1, 0.1] split
train, test, val = split(triples, ratios)
ratios = [0.8, 0.1, 0.1] # also makes a [0.8, 0.1, 0.1] split
train, test, val = split(triples, ratios)
"""
# backwards compatibility
splitter_cls: type[Splitter] = splitter_resolver.lookup(method)
kwargs = dict()
if splitter_cls is CleanupSplitter and randomize_cleanup:
kwargs["cleaner"] = cleaner_resolver.normalize_cls(RandomizedCleaner)
return splitter_resolver.make(splitter_cls, pos_kwargs=kwargs).split(
mapped_triples=mapped_triples,
ratios=ratios,
random_state=random_state,
)
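# Illustrative end-to-end sketch (not part of the original module; the helper
# name and random toy triples are hypothetical): an approximate 80/10/10 split
# with a fixed seed; sizes may deviate slightly so that train covers everything.
def _example_split() -> None:
    mapped_triples = torch.randint(10, size=(100, 3))
    train, test, validation = split(mapped_triples, ratios=[0.8, 0.1, 0.1], random_state=42)
    # the three parts partition the input triples
    assert train.shape[0] + test.shape[0] + validation.shape[0] == 100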