Source code for pykeen.triples.instances

# -*- coding: utf-8 -*-

"""Implementation of basic instance factory which creates just instances based on standard KG triples."""

from dataclasses import dataclass
from typing import Mapping

import numpy as np
import scipy.sparse
from torch.utils import data

from ..typing import MappedTriples
from ..utils import fix_dataclass_init_docs

# Public API of this module: the instance containers handed to training loops.
__all__ = [
    "Instances",
    "SLCWAInstances",
    "LCWAInstances",
    "MultimodalInstances",
    "MultimodalSLCWAInstances",
    "MultimodalLCWAInstances",
]


@fix_dataclass_init_docs
@dataclass
class Instances(data.Dataset):
    """Triples and mappings to their indices.

    Abstract base for the instance containers in this module; concrete subclasses
    provide ``__len__`` and ``__getitem__``.
    """

    def __len__(self):  # noqa:D401
        """The number of instances; subclasses must override this."""
        raise NotImplementedError
@fix_dataclass_init_docs
@dataclass
class SLCWAInstances(Instances):
    """Triples and mappings to their indices for sLCWA.

    Every instance is a single row of :attr:`mapped_triples`.
    """

    #: The mapped triples, shape: (num_triples, 3)
    mapped_triples: MappedTriples

    def __len__(self) -> int:  # noqa: D105
        num_rows = self.mapped_triples.shape[0]
        return num_rows

    def __getitem__(self, item):  # noqa: D105
        triples = self.mapped_triples
        return triples[item]
@fix_dataclass_init_docs
@dataclass
class LCWAInstances(Instances):
    """Triples and mappings to their indices for LCWA.

    Each instance is one unique (head, relation) pair taken from the triples, together with a
    label vector over all entities. The vector is 1 for every tail observed with that pair and
    0 otherwise.
    """

    #: The unique (head, relation) pairs, shape: (num_pairs, 2)
    pairs: np.ndarray
    #: The compressed triples in CSR format, shape: (num_pairs, num_entities);
    #: row i holds the tail labels of ``pairs[i]``
    compressed: scipy.sparse.csr_matrix

    @classmethod
    def from_triples(cls, mapped_triples: MappedTriples, num_entities: int) -> 'LCWAInstances':
        """
        Create LCWA instances from triples.

        :param mapped_triples: shape: (num_triples, 3)
            The ID-based triples. Duplicate triples are allowed; they produce the same label
            as a single occurrence.
        :param num_entities:
            The number of entities, i.e. the width of each label vector. Every tail ID must be
            smaller than this value, otherwise the sparse matrix construction raises.

        :return:
            The instances.
        """
        # ``.cpu()`` is a no-op for CPU tensors and makes tensors on other devices convertible
        mapped_triples = mapped_triples.cpu().numpy()
        unique_hr, pair_of_triple = np.unique(mapped_triples[:, :2], return_inverse=True, axis=0)
        # Some NumPy versions (2.0.0) return a non-1-D inverse when ``axis`` is given; flatten it
        pair_of_triple = pair_of_triple.reshape(-1)
        num_pairs = unique_hr.shape[0]
        tails = mapped_triples[:, 2]
        compressed = scipy.sparse.coo_matrix(
            (np.ones(mapped_triples.shape[0], dtype=np.float32), (pair_of_triple, tails)),
            shape=(num_pairs, num_entities),
        )
        # convert to csr for fast row slicing
        compressed = compressed.tocsr()
        # The COO -> CSR conversion sums duplicate (row, column) entries, so a triple occurring
        # k times would get label k. Every stored entry is a positive count of ones, so
        # resetting them all to 1 restores binary labels.
        compressed.data[:] = 1.0
        return cls(pairs=unique_hr, compressed=compressed)

    def __len__(self) -> int:  # noqa: D105
        return self.pairs.shape[0]

    def __getitem__(self, item):  # noqa: D105
        # ``toarray`` yields a plain (1, num_entities) ndarray; return its single row
        return self.pairs[item], self.compressed[item, :].toarray()[0, :]
@fix_dataclass_init_docs
@dataclass
class MultimodalInstances(Instances):
    """Triples and mappings to their indices, extended with multimodal (literal) data."""

    #: TODO: do we need these?
    #: Numeric literal arrays, keyed by name
    numeric_literals: Mapping[str, np.ndarray]
    #: Mapping from literal name to its integer ID
    literals_to_id: Mapping[str, int]
@fix_dataclass_init_docs
@dataclass
class MultimodalSLCWAInstances(SLCWAInstances, MultimodalInstances):
    """sLCWA triples and index mappings, combined with multimodal data.

    Fields come from both parent dataclasses; item access and length are inherited
    from :class:`SLCWAInstances`.
    """
@fix_dataclass_init_docs
@dataclass
class MultimodalLCWAInstances(LCWAInstances, MultimodalInstances):
    """LCWA triples and index mappings, combined with multimodal data.

    Fields come from both parent dataclasses; item access and length are inherited
    from :class:`LCWAInstances`.
    """