"""Combination strategies for entity alignment datasets."""
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, ClassVar, Generic, NamedTuple, TypeVar
import numpy
import pandas
import torch
from class_resolver import ClassResolver
from pandas.api.types import is_numeric_dtype, is_string_dtype
from ...triples import CoreTriplesFactory, TriplesFactory
from ...triples.utils import get_num_ids
from ...typing import (
COLUMN_HEAD,
COLUMN_TAIL,
EA_SIDE_LEFT,
EA_SIDE_RIGHT,
EA_SIDES,
LongTensor,
MappedTriples,
TargetColumn,
)
from ...utils import format_relative_comparison, get_connected_components
__all__ = [
    # Base class
    "GraphPairCombinator",
    # Implementations
    "DisjointGraphPairCombinator",
    "SwapGraphPairCombinator",
    "ExtraRelationGraphPairCombinator",
    "CollapseGraphPairCombinator",
    # Result container
    "ProcessedTuple",
]

# module-level logger, named after this module
logger = logging.getLogger(__name__)
def cat_shift_triples(*triples: CoreTriplesFactory | MappedTriples) -> tuple[MappedTriples, LongTensor]:
    """
    Stack several sets of ID-based triples after shifting their IDs apart.

    :param triples:
        the inputs, each either a triples factory or a tensor of mapped triples

    :return:
        a tuple `(combined_triples, offsets)`, where

        * `combined_triples`, shape: `(sum(map(len, triples)), 3)`, is the concatenation of the shifted mapped
          triples such that there is no overlap, and
        * `offsets`, shape: `(len(triples), 2)` comprises the entity & relation offsets for the individual factories
    """
    # row i holds the (entity, relation) shift applied to the i-th input; the first input is not shifted
    offsets = torch.zeros(len(triples), 2, dtype=torch.long)
    shifted = []
    for position, item in enumerate(triples):
        if isinstance(item, CoreTriplesFactory):
            # factories report their ID counts directly
            num_entities, num_relations = item.num_entities, item.num_relations
            ids = item.mapped_triples
        else:
            # raw tensors: derive the counts from the entity columns (0, 2) and the relation column (1)
            ids = item
            num_entities = get_num_ids(ids[:, [0, 2]])
            num_relations = get_num_ids(ids[:, 1])
        # head and tail get the entity offset, the relation column gets the relation offset
        shifted.append(ids + offsets[None, position, [0, 1, 0]])
        # every later input has to be moved past the IDs consumed by this one
        offsets[position + 1 :, 0] += num_entities
        offsets[position + 1 :, 1] += num_relations
    return torch.cat(shifted), offsets
def merge_label_to_id_mapping(
*pairs: tuple[str, Mapping[str, int]],
offsets: LongTensor | None = None,
mappings: Sequence[Mapping[int, int]] | None = None,
extra: Mapping[str, int] | None = None,
) -> dict[str, int]:
"""
Merge label-to-id mappings.
:param pairs:
pairs of `(prefix, label_to_id)`, where `label_to_id` is the label-to-id mapping, and `prefix` is a string
prefix to prepend to each key of the label-to-id mapping.
:param offsets: shape: `(len(pairs),)`
id offsets for each pair
:param mappings:
explicit id remappings for each pair
:param extra:
extra entries to add after merging
:return:
a merged label-to-id mapping
:raises ValueError:
if not exactly one of `offsets` or `mappings` is provided
"""
if (offsets is None and mappings is None) or (offsets is not None and mappings is not None):
raise ValueError("Exactly one of `offsets` or `mappings` has to be provided")
# merge labels with same ID
value_to_keys: defaultdict[int, set[str]] = defaultdict(set)
for i, (prefix, mapping) in enumerate(pairs):
for key, value in mapping.items():
key = f"{prefix}:{key}"
if offsets is None:
# for mypy
assert mappings is not None
value = mappings[i][value]
else:
value = value + offsets[i].item()
value_to_keys[value].add(key)
if extra:
for k, v in extra.items():
value_to_keys[v].add(k)
# reconstruct label-to-id
result: dict[str, int] = {}
for value, keys in value_to_keys.items():
if len(keys) == 1:
key = list(keys)[0]
else:
key = str(set(keys))
result[key] = value
return result
def merge_label_to_id_mappings(
    left: TriplesFactory,
    right: TriplesFactory,
    relation_offsets: LongTensor,
    # optional
    entity_offsets: LongTensor | None = None,
    entity_mappings: Sequence[Mapping[int, int]] | None = None,
    extra_relations: Mapping[str, int] | None = None,
) -> tuple[Mapping[str, int], Mapping[str, int]]:
    """
    Build the joint entity-to-id and relation-to-id mappings of two factories.

    :param left:
        the left triples factory
    :param right:
        the right triples factory
    :param relation_offsets: shape: (2,)
        the relation offsets
    :param entity_offsets: shape: (2,)
        the entity offsets, if entities are shifted uniformly
    :param entity_mappings:
        explicit entity ID remappings
    :param extra_relations:
        additional relations, as a mapping from their label to IDs

    :return:
        the updated entity and relation to id mappings
    """
    # the side name doubles as the label prefix
    sides = ((EA_SIDE_LEFT, left), (EA_SIDE_RIGHT, right))
    # entities: translated either by offsets or by explicit remappings
    entity_to_id = merge_label_to_id_mapping(
        *((side, factory.entity_to_id) for side, factory in sides),
        offsets=entity_offsets,
        mappings=entity_mappings,
    )
    # relations: always shifted by offsets, optionally extended by extra relations
    relation_to_id = merge_label_to_id_mapping(
        *((side, factory.relation_to_id) for side, factory in sides),
        offsets=relation_offsets,
        extra=extra_relations,
    )
    return entity_to_id, relation_to_id
def filter_map_alignment(
    alignment: pandas.DataFrame,
    left: CoreTriplesFactory,
    right: CoreTriplesFactory,
    entity_offsets: LongTensor,
) -> LongTensor:
    """
    Convert dataframe with label or ID-based alignment.

    The passed dataframe is not modified; label-to-ID conversion and filtering happen on a copy.

    :param alignment: columns: EA_SIDES
        the dataframe with the alignment
    :param left:
        the triples factory of the left graph
    :param right:
        the triples factory of the right graph
    :param entity_offsets: shape: (2,)
        the entity offsets from old to new IDs

    :return: shape: (2, num_alignments)
        the ID-based alignment in new IDs

    :raises ValueError:
        if the datatype of the alignment data frame is incompatible (neither string nor integer).
    """
    # work on a copy, so replacing label columns by ID columns does not leak into the caller's dataframe
    alignment = alignment.copy()

    # convert labels to IDs
    for side, tf in zip(EA_SIDES, (left, right), strict=False):
        if isinstance(tf, TriplesFactory) and is_string_dtype(alignment[side]):
            logger.debug(f"Mapping label-based alignment for {side}")
            # map labels, using -1 as fill-value for invalid labels
            # we cannot drop them here, since the two columns need to stay aligned
            alignment[side] = alignment[side].apply(tf.entity_to_id.get, args=(-1,))
        if not is_numeric_dtype(alignment[side]):
            raise ValueError(f"Invalid dtype in alignment dataframe for side={side}: {alignment[side].dtype}")

    # filter alignment: a row is invalid if any ID is negative or not smaller than the side's entity count
    ids = alignment.to_numpy()
    # `ndarray.reshape` avoids the `newshape=` keyword of `numpy.reshape`, which newer NumPy versions deprecate
    num_entities = numpy.asarray([left.num_entities, right.num_entities]).reshape(1, 2)
    invalid_mask = (ids < 0).any(axis=1) | (ids >= num_entities).any(axis=1)
    if invalid_mask.any():
        logger.warning(
            f"Dropping {format_relative_comparison(part=invalid_mask.sum(), total=alignment.shape[0])} "
            f"alignments due to invalid labels.",
        )
        alignment = alignment.loc[~invalid_mask]

    # map alignment from old IDs to new IDs
    return torch.as_tensor(alignment.to_numpy().T, dtype=torch.long) + entity_offsets.view(2, 1)
def swap_index_triples(
    mapped_triples: MappedTriples,
    dense_map: LongTensor,
    index: TargetColumn,
) -> MappedTriples:
    """
    Build copies of triples in which the ID in one column is exchanged for its partner.

    :param mapped_triples: shape: (m, 3)
        the id-based triples
    :param dense_map: shape: (n,)
        a dense map between IDs. Contains `-1` for missing entries
    :param index:
        the index, a number between 0 (incl) and 3 (excl).

    :return:
        all triples which contain a key at the `index` position, where the key has been replaced by the corresponding
        value in the dense map
    """
    # look up the partner of every ID in the selected column
    partners = dense_map[mapped_triples[:, index]]
    # negative entries mean "no partner"; those triples are dropped
    has_partner = partners >= 0
    # boolean selection followed by clone, so the input tensor is left untouched
    swapped = mapped_triples[has_partner].clone()
    swapped[:, index] = partners[has_partner]
    return swapped
class ProcessedTuple(NamedTuple):
    """The outcome of post-processing a merged pair of graphs."""

    #: merged triples in ID form, shape: (n, 3)
    mapped_triples: MappedTriples

    #: alignment after processing, shape: (2, m)
    alignment: LongTensor

    #: keyword arguments used afterwards to adjust the label-to-id mappings
    translation_kwargs: Mapping[str, Any]
# Type variable for the factory type handled by a combinator; constrained to
# either CoreTriplesFactory or TriplesFactory.
FactoryType = TypeVar("FactoryType", CoreTriplesFactory, TriplesFactory)
class GraphPairCombinator(Generic[FactoryType], ABC):
    """Base class for strategies which merge a pair of graphs into one graph."""

    def __call__(
        self,
        left: FactoryType,
        right: FactoryType,
        alignment: pandas.DataFrame,
        **kwargs,
    ) -> tuple[FactoryType, LongTensor]:
        """
        Merge two graphs, taking the alignment between them into account.

        :param left:
            the triples of the left graph
        :param right:
            the triples of the right graph
        :param alignment: columns: LEFT | RIGHT
            the alignment, i.e., pairs of matching entities
        :param kwargs:
            additional keyword-based parameters passed to :meth:`TriplesFactory.__init__`
            (or to :meth:`CoreTriplesFactory.create`, if no label information is available)

        :return:
            a single triples factory comprising the joint graph, as well as a tensor of pairs of matching IDs.
            The tensor of matching pairs has shape `(2, num_alignments)`, where `num_alignments` can also be 0.
        """
        # 1) stack both graphs using disjoint ID ranges
        merged_triples, offsets = cat_shift_triples(left, right)

        # 2) drop unusable alignment rows and express the remaining ones in the shifted entity IDs
        id_alignment = filter_map_alignment(
            alignment=alignment,
            left=left,
            right=right,
            entity_offsets=offsets[:, 0],
        )

        # 3) strategy-specific processing
        # TODO: restrict to only using training alignments?
        merged_triples, id_alignment, translation_kwargs = self.process(merged_triples, id_alignment, offsets)

        combined_factory: FactoryType
        if not (isinstance(left, TriplesFactory) and isinstance(right, TriplesFactory)):
            # without labels on both sides there is nothing to merge besides the triples
            combined_factory = CoreTriplesFactory.create(mapped_triples=merged_triples, **kwargs)
            return combined_factory, id_alignment

        # both sides carry labels: merge the label-to-id mappings as instructed by the strategy
        entity_to_id, relation_to_id = merge_label_to_id_mappings(left=left, right=right, **translation_kwargs)
        combined_factory = TriplesFactory(
            mapped_triples=merged_triples,
            entity_to_id=entity_to_id,
            relation_to_id=relation_to_id,
            **kwargs,
        )
        return combined_factory, id_alignment

    @abstractmethod
    def process(
        self,
        mapped_triples: MappedTriples,
        alignment: LongTensor,
        offsets: LongTensor,
    ) -> ProcessedTuple:
        """
        Post-process the stacked triples of both graphs.

        :param mapped_triples: shape: (n, 3)
            the id-based merged triples
        :param alignment: shape: (2, m)
            the id-based entity alignment
        :param offsets: shape: (2, 2)
            the entity and relation offsets from merging

        :return:
            updated triples and alignment tensor, as well as parameters for updating label-to-id mappings
        """
        raise NotImplementedError
class DisjointGraphPairCombinator(GraphPairCombinator[FactoryType]):
    """A combinator which leaves both graphs side by side as disconnected components."""

    # docstr-coverage: inherited
    def process(
        self,
        mapped_triples: MappedTriples,
        alignment: LongTensor,
        offsets: LongTensor,
    ) -> ProcessedTuple:  # noqa: D102
        # triples and alignment pass through unchanged; only the offsets are forwarded for label merging
        entity_offsets, relation_offsets = offsets[:, 0], offsets[:, 1]
        return ProcessedTuple(
            mapped_triples=mapped_triples,
            alignment=alignment,
            translation_kwargs={"entity_offsets": entity_offsets, "relation_offsets": relation_offsets},
        )
class SwapGraphPairCombinator(GraphPairCombinator[FactoryType]):
    """Add extra triples by swapping aligned entities.

    For each aligned pair `e1 ~ e2`, every triple with `e1` (or `e2`) as head is duplicated with the partner as
    head, and likewise for the tail position. The original triples and the alignment are kept.
    """

    # docstr-coverage: inherited
    def process(
        self,
        mapped_triples: MappedTriples,
        alignment: LongTensor,
        offsets: LongTensor,
    ) -> ProcessedTuple:  # noqa: D102
        # add swap triples
        # e1 ~ e2 => (e1, r, t) ~> (e2, r, t), or (h, r, e1) ~> (h, r, e2)
        # create dense entity remapping for swap
        num_ids = get_num_ids(mapped_triples[:, 0::2])
        # An aligned entity need not occur in any triple, so the count derived from the triples alone may be too
        # small to index the map with every aligned ID; grow it to cover the largest aligned ID as well.
        if alignment.numel() > 0:
            num_ids = max(num_ids, int(alignment.max()) + 1)
        # -1 marks entities without a swapping partner
        dense_map = torch.full(size=(num_ids,), fill_value=-1)
        left_id, right_id = alignment
        dense_map[left_id] = right_id
        dense_map[right_id] = left_id
        # add swapped triples
        mapped_triples = torch.cat(
            [
                mapped_triples,
                # swap head
                swap_index_triples(mapped_triples=mapped_triples, dense_map=dense_map, index=COLUMN_HEAD),
                # swap tail
                swap_index_triples(mapped_triples=mapped_triples, dense_map=dense_map, index=COLUMN_TAIL),
            ],
            dim=0,
        )
        return ProcessedTuple(
            mapped_triples,
            alignment,
            dict(entity_offsets=offsets[:, 0], relation_offsets=offsets[:, 1]),
        )
def iter_entity_mappings(
    *old_new_ids_pairs: tuple[LongTensor, LongTensor], offsets: LongTensor
) -> Iterable[Mapping[int, int]]:
    """
    Derive explicit ID remappings, one per offset range.

    :param old_new_ids_pairs:
        aligned pairs of old and new ids
    :param offsets: shape: (2,)
        the entity offsets

    :yields: explicit id remappings
    """
    # gather all old IDs and all new IDs into one tensor each
    old_parts, new_parts = zip(*old_new_ids_pairs, strict=False)
    old_ids = torch.cat(old_parts, dim=0)
    new_ids = torch.cat(new_parts, dim=0)
    # consecutive boundaries delimit the old-ID range of each side; the last range is closed by the ID count
    boundaries = [*offsets.tolist(), get_num_ids(old_ids)]
    for start, stop in zip(boundaries[:-1], boundaries[1:], strict=False):
        in_range = (old_ids >= start) & (old_ids < stop)
        # keys are expressed relative to the start of the range
        local_old = (old_ids[in_range] - start).tolist()
        yield dict(zip(local_old, new_ids[in_range].tolist(), strict=False))
class CollapseGraphPairCombinator(GraphPairCombinator[FactoryType]):
    """This combinator merges all matching entity pairs into a single ID.

    Aligned entities are grouped via connected components, each group is represented by its smallest ID, and the
    resulting entity IDs are re-numbered consecutively. The returned alignment is empty.
    """

    # docstr-coverage: inherited
    def process(
        self,
        mapped_triples: MappedTriples,
        alignment: LongTensor,
        offsets: LongTensor,
    ) -> ProcessedTuple:  # noqa: D102
        # determine connected components regarding the same-as relation (i.e., applies transitivity)
        num_ids = get_num_ids(mapped_triples[:, 0::2])
        # An aligned entity need not occur in any triple, so the count derived from the triples alone may be too
        # small to index the mapping with every aligned ID; grow it to cover the largest aligned ID as well.
        if alignment.numel() > 0:
            num_ids = max(num_ids, int(alignment.max()) + 1)
        # start from the identity mapping, then point each component to its smallest member
        entity_id_mapping = torch.arange(num_ids)
        for cc in get_connected_components(pairs=alignment.t().tolist()):
            cc = list(cc)
            entity_id_mapping[cc] = min(cc)
        # apply id mapping
        h, r, t = mapped_triples.t()
        h_new, t_new = entity_id_mapping[h], entity_id_mapping[t]
        # ensure consecutive IDs
        inverse = torch.cat([h_new, t_new]).unique(return_inverse=True)[1]
        h_new, t_new = inverse.split((len(h), len(t)))
        mapped_triples = torch.stack([h_new, r, t_new], dim=-1)
        # only use training alignments?
        return ProcessedTuple(
            mapped_triples,
            # no alignment pairs are returned
            torch.empty(size=(2, 0), dtype=torch.long),
            dict(
                entity_mappings=list(iter_entity_mappings((h, h_new), (t, t_new), offsets=offsets[:, 0])),
                relation_offsets=offsets[:, 1],
            ),
        )
# Resolver built from the subclasses of GraphPairCombinator.
# NOTE(review): the default, ExtraRelationGraphPairCombinator, is listed in `__all__` but has no definition in the
# visible part of this file - confirm it is defined (or imported) before this statement runs.
graph_combinator_resolver: ClassResolver[GraphPairCombinator] = ClassResolver.from_subclasses(
base=GraphPairCombinator,
default=ExtraRelationGraphPairCombinator,
)