# -*- coding: utf-8 -*-
"""Implementation of basic instance factories that create instances from standard knowledge graph (KG) triples."""
from dataclasses import dataclass
from typing import Mapping
import numpy as np
import scipy.sparse
from torch.utils import data
from ..typing import MappedTriples
from ..utils import fix_dataclass_init_docs
# Public API of this module.
__all__ = [
    # generic base class
    'Instances',
    # instance types for the sLCWA / LCWA training regimes
    'SLCWAInstances',
    'LCWAInstances',
    # variants that additionally carry multimodal (literal) data
    'MultimodalInstances',
    'MultimodalSLCWAInstances',
    'MultimodalLCWAInstances',
]
@fix_dataclass_init_docs
@dataclass
class Instances(data.Dataset):
    """Triples and mappings to their indices.

    Abstract base for the dataset views used during training; concrete
    subclasses store their data as dataclass fields and implement
    ``__len__`` and ``__getitem__``.
    """

    def __len__(self) -> int:  # noqa:D401
        """The number of instances."""
        raise NotImplementedError
@fix_dataclass_init_docs
@dataclass
class SLCWAInstances(Instances):
    """Triples and mappings to their indices for sLCWA.

    Each instance is a single ID-based triple, i.e. one row of
    ``mapped_triples``.
    """

    #: The mapped triples, shape: (num_triples, 3)
    mapped_triples: MappedTriples

    def __len__(self) -> int:  # noqa: D105
        return self.mapped_triples.shape[0]

    def __getitem__(self, item):  # noqa: D105
        return self.mapped_triples[item]
@fix_dataclass_init_docs
@dataclass
class LCWAInstances(Instances):
    """Triples and mappings to their indices for LCWA.

    Each instance is one unique (head, relation) pair together with a dense
    multi-hot vector over all entities that marks the tails observed with it.
    """

    #: The unique (head, relation) pairs, shape: (num_pairs, 2)
    pairs: np.ndarray

    #: The compressed triples in CSR format, shape: (num_pairs, num_entities);
    #: row ``i`` marks the tails observed together with ``pairs[i]``
    compressed: scipy.sparse.csr_matrix

    @classmethod
    def from_triples(
        cls,
        mapped_triples: MappedTriples,
        num_entities: int,
        **kwargs,
    ) -> 'LCWAInstances':
        """
        Create LCWA instances from triples.

        :param mapped_triples: shape: (num_triples, 3)
            The ID-based triples.
        :param num_entities:
            The number of entities.
        :param kwargs:
            Additional keyword arguments passed on to the constructor, e.g. extra
            dataclass fields required by subclasses.
        :return:
            The instances.
        """
        # ``Tensor.numpy()`` only works for CPU tensors; ``.cpu()`` is a no-op there
        mapped_triples = mapped_triples.cpu().numpy()
        unique_hr, triple_idx_to_pair_idx = np.unique(mapped_triples[:, :2], return_inverse=True, axis=0)
        # some NumPy releases (2.0.0) return a non-1-D inverse when ``axis`` is given;
        # the COO constructor needs a flat row-index vector
        triple_idx_to_pair_idx = np.ravel(triple_idx_to_pair_idx)
        num_pairs = unique_hr.shape[0]
        tails = mapped_triples[:, 2]
        compressed = scipy.sparse.coo_matrix(
            (np.ones(mapped_triples.shape[0], dtype=np.float32), (triple_idx_to_pair_idx, tails)),
            shape=(num_pairs, num_entities),
        )
        # convert to csr for fast row slicing; this sums duplicate entries, so clamp
        # them back to 1 to keep the targets binary even if a triple is repeated
        compressed = compressed.tocsr()
        compressed.data[:] = 1.0
        return cls(pairs=unique_hr, compressed=compressed, **kwargs)

    def __len__(self) -> int:  # noqa: D105
        return self.pairs.shape[0]

    def __getitem__(self, item):  # noqa: D105
        # slicing one CSR row gives a 1 x num_entities matrix; densify and take its single row
        return self.pairs[item], self.compressed[item, :].toarray()[0, :]
@fix_dataclass_init_docs
@dataclass
class MultimodalInstances(Instances):
    """Triples and mappings to their indices as well as multimodal data."""

    # TODO: do we need these?

    #: Numeric literal arrays, keyed by string
    numeric_literals: Mapping[str, np.ndarray]

    #: Mapping from literal keys to integer IDs
    literals_to_id: Mapping[str, int]
@fix_dataclass_init_docs
@dataclass
class MultimodalSLCWAInstances(SLCWAInstances, MultimodalInstances):
    """Triples and mappings to their indices as well as multimodal data for sLCWA.

    Combines the dataclass fields of :class:`SLCWAInstances` and
    :class:`MultimodalInstances`; since the fields come from both bases,
    prefer keyword arguments when constructing it.
    """
@fix_dataclass_init_docs
@dataclass
class MultimodalLCWAInstances(LCWAInstances, MultimodalInstances):
    """Triples and mappings to their indices as well as multimodal data for LCWA.

    Combines the dataclass fields of :class:`LCWAInstances` and
    :class:`MultimodalInstances`; since the fields come from both bases,
    prefer keyword arguments when constructing it.
    """