"""Utilities for generating triples."""
import torch
from .triples_factory import CoreTriplesFactory
from .utils import get_entities, get_relations
from ..typing import MappedTriples, TorchRandomHint
from ..utils import ensure_torch_random_state
__all__ = [
"generate_triples",
"generate_triples_factory",
]
[docs]
def generate_triples(
num_entities: int = 33,
num_relations: int = 7,
num_triples: int = 101,
compact: bool = True,
random_state: TorchRandomHint = None,
) -> MappedTriples:
"""Generate random triples in a torch tensor."""
random_state = ensure_torch_random_state(random_state)
rv = torch.stack(
[
torch.randint(num_entities, size=(num_triples,), generator=random_state),
torch.randint(num_relations, size=(num_triples,), generator=random_state),
torch.randint(num_entities, size=(num_triples,), generator=random_state),
],
dim=1,
)
# ensure that each entity & relation occurs at least once
idx = torch.randperm(num_triples)[:num_entities]
rv[idx, 0] = torch.arange(num_entities)
idx = torch.randperm(num_triples)[:num_relations]
rv[idx, 1] = torch.arange(num_relations)
if compact:
new_entity_id = {entity: i for i, entity in enumerate(sorted(get_entities(rv)))}
new_relation_id = {relation: i for i, relation in enumerate(sorted(get_relations(rv)))}
rv = torch.as_tensor(
data=[[new_entity_id[h], new_relation_id[r], new_entity_id[t]] for h, r, t in rv.tolist()],
dtype=torch.long,
)
return rv
[docs]
def generate_triples_factory(
num_entities: int = 33,
num_relations: int = 7,
num_triples: int = 101,
random_state: TorchRandomHint = None,
create_inverse_triples: bool = False,
) -> CoreTriplesFactory:
"""Generate a triples factory with random triples."""
mapped_triples = generate_triples(
num_entities=num_entities,
num_relations=num_relations,
num_triples=num_triples,
random_state=random_state,
)
return CoreTriplesFactory.create(
mapped_triples=mapped_triples,
create_inverse_triples=create_inverse_triples,
)