# -*- coding: utf-8 -*-
"""Instance creation utilities."""
from typing import Callable, Mapping, Optional, Sequence, Set, TextIO, Union
import numpy as np
import torch
from pkg_resources import iter_entry_points
from ..typing import LabeledTriples
# Public API of this module: only these names are exported by ``from ... import *``.
# The importer registries and the ``_load_importers`` helper are deliberately left out.
__all__ = [
'load_triples',
'get_entities',
'get_relations',
]
def _load_importers(group_subname: str) -> Mapping[str, Callable[[str], LabeledTriples]]:
    """Collect the importer callables registered under a PyKEEN entry point group.

    :param group_subname: The last component of the entry point group name; the group that is
        scanned is ``pykeen.triples.<group_subname>``.
    :returns: A dictionary from each entry point's name to the object obtained by loading it.
    """
    group = f'pykeen.triples.{group_subname}'
    return {
        entry_point.name: entry_point.load()
        for entry_point in iter_entry_points(group=group)
    }
#: Functions for specifying exotic resources with a given prefix.
#: Keys are prefixes: ``load_triples`` hands a string path of the form ``<prefix>:<rest>`` to the
#: matching function, passing only ``<rest>``. Populated once, at import time, from the
#: ``pykeen.triples.prefix_importer`` entry point group.
PREFIX_IMPORTERS: Mapping[str, Callable[[str], LabeledTriples]] = _load_importers('prefix_importer')
#: Functions for specifying exotic resources based on their file extension.
#: Keys are extensions without the leading dot: ``load_triples`` hands a string path ending in
#: ``.<extension>`` to the matching function, passing the full path. Populated once, at import
#: time, from the ``pykeen.triples.extension_importer`` entry point group.
EXTENSION_IMPORTERS: Mapping[str, Callable[[str], LabeledTriples]] = _load_importers('extension_importer')
def load_triples(
    path: Union[str, TextIO],
    delimiter: str = '\t',
    encoding: Optional[str] = None,
    column_remapping: Optional[Sequence[int]] = None,
) -> LabeledTriples:
    """Load triples saved as tab separated values.

    :param path: The key for the data to be loaded. Typically, this will be a file path ending in ``.tsv``
        that points to a file with three columns - the head, relation, and tail. This can also be used to
        invoke PyKEEN data importer entrypoints (see below). An already opened text file is accepted as
        well; importer entrypoints are only considered for string paths.
    :param delimiter: The delimiter between the columns in the file
    :param encoding: The encoding for the file. Defaults to utf-8.
    :param column_remapping: A remapping if the three columns do not follow the order head-relation-tail.
        For example, if the order is head-tail-relation, pass ``(0, 2, 1)``. It is only applied when the
        data is read as delimited text; it is not forwarded to importer entrypoints.

    :returns: A numpy array representing "labeled" triples. When the data is read as delimited text, the
        array is always two-dimensional with one row per line, even if the file contains a single triple.
        If an importer entrypoint handles the path, its return value is passed through unchanged.

    :raises ValueError: if a column remapping was passed but it was not a length 3 sequence

    Besides TSV handling, PyKEEN does not come with any importers pre-installed. A few can be found at:

    - :mod:`pybel.io.pykeen`
    - :mod:`bio2bel.io.pykeen`
    """
    if isinstance(path, str):
        # Registered importers take precedence over plain text loading: first by file extension ...
        for extension, handler in EXTENSION_IMPORTERS.items():
            if path.endswith(f'.{extension}'):
                return handler(path)

        # ... then by a "<prefix>:" scheme, which is stripped before the handler is called.
        for prefix, handler in PREFIX_IMPORTERS.items():
            if path.startswith(f'{prefix}:'):
                return handler(path[len(f'{prefix}:'):])

    # Reject an unusable remapping before reading a potentially large file.
    if column_remapping is not None and len(column_remapping) != 3:
        raise ValueError('remapping must have length of three')

    if encoding is None:
        encoding = 'utf-8'

    rv = np.loadtxt(
        fname=path,
        dtype=str,
        # Replaces numpy's default '#' comment marker, so '#' inside labels is kept as data and only
        # lines carrying this exact header text are cut off.
        comments='@Comment@ Head Relation Tail',
        delimiter=delimiter,
        encoding=encoding,
        # Without this, a file with a single line would come back as a 1-D array, which breaks the
        # column indexing below and gives callers an inconsistent shape.
        ndmin=2,
    )
    if column_remapping is not None:
        # A plain list is an unambiguous fancy index regardless of the concrete sequence type passed.
        rv = rv[:, list(column_remapping)]
    return rv
def get_entities(triples: torch.LongTensor) -> Set[int]:
    """Get all entities from the triples.

    :param triples: A two-dimensional tensor of triples with at least three columns. Columns 0 and 2
        are read as the entity positions.
    :returns: The set of distinct values that occur in column 0 or column 2.
    """
    # Select the first and third column of every row, then collapse them into one flat list.
    entity_columns = triples[:, [0, 2]]
    return set(entity_columns.flatten().tolist())
def get_relations(triples: torch.LongTensor) -> Set[int]:
    """Get all relations from the triples.

    :param triples: A two-dimensional tensor of triples with at least two columns. Column 1 is read
        as the relation position.
    :returns: The set of distinct values that occur in column 1.
    """
    relation_column = triples[:, 1]
    return set(relation_column.tolist())