# Source code for pykeen.datasets.ogb

# -*- coding: utf-8 -*-

"""Load the OGB datasets.

Run with python -m pykeen.datasets.ogb
"""

from typing import ClassVar, Optional

import numpy as np

from .base import LazyDataset
from ..triples import TriplesFactory

__all__ = [
    'OGBLoader',
    'OGBBioKG',
    'OGBWikiKG',
]


class OGBLoader(LazyDataset):
    """Load from the Open Graph Benchmark (OGB).

    Subclasses set :attr:`name` to the OGB link property prediction dataset
    that should be loaded. The ``ogb`` package is imported lazily inside
    :meth:`_load`, so it is only required once the data is actually loaded.
    """

    #: The name of the dataset to download
    name: ClassVar[str]

    def __init__(self, cache_root: Optional[str] = None, create_inverse_triples: bool = False):
        """Store the loader configuration without loading any data.

        :param cache_root: Passed through ``self._help_cache``; the result is
            later handed to ``LinkPropPredDataset`` as its ``root``.
        :param create_inverse_triples: Forwarded to
            ``TriplesFactory.from_labeled_triples`` for every split.
        """
        self.cache_root = self._help_cache(cache_root)
        self.create_inverse_triples = create_inverse_triples

    def _load(self) -> None:
        """Load the dataset via ``ogb`` and build the three triples factories.

        The testing and validation factories reuse the entity and relation
        mappings of the training factory.

        :raises ModuleNotFoundError: If the ``ogb`` package cannot be imported.
        """
        try:
            from ogb.linkproppred import LinkPropPredDataset
        except ImportError as e:
            raise ModuleNotFoundError(
                f'Need to `pip install ogb` to use pykeen.datasets.{self.__class__.__name__}.',
            ) from e

        dataset = LinkPropPredDataset(name=self.name, root=self.cache_root)
        edge_split = dataset.get_edge_split()
        self._training = self._make_tf(edge_split["train"])
        self._testing = self._make_tf(
            edge_split["test"],
            entity_to_id=self._training.entity_to_id,
            relation_to_id=self._training.relation_to_id,
        )
        self._validation = self._make_tf(
            edge_split["valid"],
            entity_to_id=self._training.entity_to_id,
            relation_to_id=self._training.relation_to_id,
        )

    def _loaded_validation(self) -> bool:
        """Return ``self._loaded``; validation is populated together with the other splits in :meth:`_load`."""
        return self._loaded

    def _load_validation(self) -> None:
        """Do nothing, because :meth:`_load` already creates the validation factory."""
        pass

    def _make_tf(self, x, entity_to_id=None, relation_to_id=None) -> TriplesFactory:
        """Build a triples factory from one split of the OGB edge split.

        :param x: A mapping with ``'head'``, ``'relation'`` and ``'tail'``
            entries that can be stacked column-wise with :func:`numpy.stack`.
        :param entity_to_id: Optional entity mapping forwarded unchanged to
            ``TriplesFactory.from_labeled_triples``.
        :param relation_to_id: Optional relation mapping forwarded unchanged to
            ``TriplesFactory.from_labeled_triples``.
        :return: The triples factory created from the stacked triples.
        """
        triples = np.stack([x['head'], x['relation'], x['tail']], axis=1)

        # FIXME these are already identifiers
        # Use the builtin ``str``: ``np.str`` was only a deprecated alias for it
        # and no longer exists in recent NumPy releases.
        triples = triples.astype(str)

        return TriplesFactory.from_labeled_triples(
            triples=triples,
            create_inverse_triples=self.create_inverse_triples,
            entity_to_id=entity_to_id,
            relation_to_id=relation_to_id,
        )


class OGBBioKG(OGBLoader):
    """The OGB BioKG dataset.

    .. seealso:: https://ogb.stanford.edu/docs/linkprop/#ogbl-biokg
    """

    #: OGB dataset identifier used by the loader
    name = 'ogbl-biokg'


class OGBWikiKG(OGBLoader):
    """The OGB WikiKG dataset.

    .. seealso:: https://ogb.stanford.edu/docs/linkprop/#ogbl-wikikg
    """

    #: OGB dataset identifier used by the loader
    name = 'ogbl-wikikg'


if __name__ == '__main__':
    # Script entry point: instantiate each dataset with default arguments
    # and call its ``summarize()`` method.
    for _cls in [OGBBioKG, OGBWikiKG]:
        _cls().summarize()