# -*- coding: utf-8 -*-

"""The Wk3l-15k dataset family.

Get a summary with ``python -m pykeen.datasets.wk3l``
"""

import logging
import pathlib
import zipfile
from abc import ABC
from typing import ClassVar, Mapping, Optional, Tuple, cast

import click
import pandas
from docdata import parse_docdata
from more_click import verbose_option
from pystow.utils import download_from_google

from .base import LazyDataset
from ..triples import TriplesFactory
from ..typing import LABEL_HEAD, LABEL_RELATION, LABEL_TAIL, TorchRandomHint

__all__ = [
    "MTransEDataset",
    "WK3l15k",
    "CN3l",
    "WK3l120k",
]

logger = logging.getLogger(__name__)

GOOGLE_DRIVE_ID = "1AsPPU4ka1Rc9u-XYMGWtvV65hF3egi0z"
GRAPH_PAIRS = ("en_fr", "en_de")


class MTransEDataset(LazyDataset, ABC):
    """Base class for WK3l datasets (WK3l-15k, WK3l-120k, CN3l)."""

    #: The mapping from (graph-pair, side) to triple file name
    FILE_NAMES: ClassVar[Mapping[Tuple[str, str], str]]

    #: The internal dataset name
    DATASET_NAME: ClassVar[str]

    #: The hex digest for the zip file
    SHA512: str = (
        "b5b64db8acec2ef83a418008e8ff6ddcd3ea1db95a0a158825ea9cffa5a3c34a"
        "9aba6945674304f8623ab21c7248fed900028e71ad602883a307364b6e3681dc"
    )

    def __init__(
        self,
        graph_pair: str = "en_de",
        side: str = "en",
        cache_root: Optional[str] = None,
        eager: bool = False,
        create_inverse_triples: bool = False,
        random_state: TorchRandomHint = 0,
        split_ratios: Tuple[float, float, float] = (0.8, 0.1, 0.1),
        force: bool = False,
    ):
        """
        Initialize the dataset.
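
        Example (a sketch; ``WK3l15k`` is one of the concrete subclasses defined
        below, and the first data access triggers the download):

        >>> from pykeen.datasets.wk3l import WK3l15k
        >>> dataset = WK3l15k(graph_pair="en_fr", side="fr")
        >>> dataset.summarize()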

        :param graph_pair:
            The graph-pair within the dataset family (cf. GRAPH_PAIRS).
        :param side:
            The side of the graph pair, i.e., one of the two language codes in the pair
            (e.g., "en" or "de" for the "en_de" pair).
        :param cache_root:
            The cache root.
        :param eager:
            Whether to directly load the dataset, or defer it to the first access of a relevant attribute.
        :param create_inverse_triples:
            Whether to create inverse triples.
        :param random_state:
            The random state used for splitting.
        :param split_ratios:
            The split ratios used for splitting the dataset into train / validation / test.
        :param force:
            Whether to enforce re-download of existing files.

        :raises ValueError:
            If the graph pair or side is invalid.
        """
        # Input validation.
        if graph_pair not in GRAPH_PAIRS:
            raise ValueError(f"Invalid graph pair: Allowed are: {GRAPH_PAIRS}")
        available_sides = graph_pair.split("_")
        if side not in available_sides:
            raise ValueError(f"side must be one of {available_sides}")

        self._relative_path = pathlib.PurePosixPath(
            "data",
            self.DATASET_NAME,
            graph_pair,
            self.FILE_NAMES[graph_pair, side],
        )

        # For downloading
        self.drive_id = GOOGLE_DRIVE_ID
        self.force = force
        self.cache_root = self._help_cache(cache_root)

        # For splitting
        self.random_state = random_state
        self.ratios = split_ratios

        # Whether to create inverse triples
        self.create_inverse_triples = create_inverse_triples

        if eager:
            self._load()

    def _extend_cache_root(self, cache_root: pathlib.Path) -> pathlib.Path:  # noqa: D102
        # Use a shared directory, since a single zip archive contains all datasets of the family.
        return cache_root.joinpath("wk3l")

    def _load(self) -> None:
        path = self.cache_root.joinpath("data.zip")

        # ensure file is present
        # TODO: Re-use ensure_from_google?
        if not path.is_file() or self.force:
            logger.info(f"Downloading file from Google Drive (ID: {self.drive_id})")
            download_from_google(self.drive_id, path, hexdigests=dict(sha512=self.SHA512))

        # read all triples from file
        with zipfile.ZipFile(path) as zf:
            logger.info(f"Reading from {path.as_uri()}")
            with zf.open(str(self._relative_path), mode="r") as triples_file:
                df = pandas.read_csv(
                    triples_file,
                    delimiter="@@@",
                    header=None,
                    names=[LABEL_HEAD, LABEL_RELATION, LABEL_TAIL],
                    engine="python",
                    encoding="utf8",
                )
        # some "entities" have numeric labels
        # pandas.read_csv(..., dtype=str) does not work properly.
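        # (e.g., an entity labeled "1984" would otherwise come back as a number)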
        df = df.astype(dtype=str)

        # create triples factory
        tf = TriplesFactory.from_labeled_triples(
            triples=df.values,
            create_inverse_triples=self.create_inverse_triples,
            metadata=dict(path=path),
        )

        # split
        self._training, self._testing, self._validation = cast(
            Tuple[TriplesFactory, TriplesFactory, TriplesFactory],
            tf.split(
                ratios=self.ratios,
                random_state=self.random_state,
            ),
        )
        logger.info("[%s] done splitting data from %s", self.__class__.__name__, path)

    def _load_validation(self) -> None:
        pass  # already loaded by _load()


@parse_docdata
class WK3l15k(MTransEDataset):
    """The WK3l-15k dataset family.

    .. note ::
        This dataset contains artifacts from incorrectly treating literals as entities.

    ---
    name: WK3l-15k Family
    citation:
        author: Chen
        year: 2017
        link: https://www.ijcai.org/Proceedings/2017/0209.pdf
    single: true
    statistics:
        entities: 15126
        relations: 1841
        triples: 209041
        training: 167232
        testing: 20904
        validation: 20905
    """

    DATASET_NAME = "WK3l-15k"
    FILE_NAMES = {
        ("en_de", "en"): "P_en_v6.csv",
        ("en_de", "de"): "P_de_v6.csv",
        ("en_fr", "en"): "P_en_v5.csv",
        ("en_fr", "fr"): "P_fr_v5.csv",
    }
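
# A minimal usage sketch (graph pair and side as in FILE_NAMES above; the shared
# zip archive is downloaded on first access):
#
#   dataset = WK3l15k(graph_pair="en_de", side="de")
#   training = dataset.training  # a TriplesFactory holding 80% of the triples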


@parse_docdata
class WK3l120k(MTransEDataset):
    """The WK3l-120k dataset family.

    .. note ::
        This dataset contains artifacts from incorrectly treating literals as entities.

    ---
    name: WK3l-120k Family
    citation:
        author: Chen
        year: 2017
        link: https://www.ijcai.org/Proceedings/2017/0209.pdf
    single: true
    statistics:
        entities: 119748
        relations: 3109
        triples: 1375406
        training: 499727
        testing: 62466
        validation: 62466
    """

    DATASET_NAME = "WK3l-120k"
    FILE_NAMES = {
        ("en_de", "en"): "P_en_v6_120k.csv",
        ("en_de", "de"): "P_de_v6_120k.csv",
        ("en_fr", "en"): "P_en_v5_120k.csv",
        ("en_fr", "fr"): "P_fr_v5_120k.csv",
    }


@parse_docdata
class CN3l(MTransEDataset):
    """The CN3l dataset family.

    ---
    name: CN3l Family
    citation:
        author: Chen
        year: 2017
        link: https://www.ijcai.org/Proceedings/2017/0209.pdf
    single: true
    statistics:
        entities: 3206
        relations: 42
        triples: 21777
        training: 23492
        testing: 2936
        validation: 2937
    """

    DATASET_NAME = "CN3l"
    FILE_NAMES = {
        ("en_de", "en"): "C_en_d.csv",
        ("en_de", "de"): "C_de.csv",
        ("en_fr", "en"): "C_en_f.csv",
        ("en_fr", "fr"): "C_fr.csv",
    }
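
# End-to-end sketch (assuming pykeen.pipeline.pipeline, which also accepts
# dataset instances): train a TransE model on one side of CN3l.
#
#   from pykeen.pipeline import pipeline
#
#   result = pipeline(dataset=CN3l(graph_pair="en_de", side="de"), model="TransE")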


@click.command()
@verbose_option
def _main():
    for graph_pair in GRAPH_PAIRS:
        for side in graph_pair.split("_"):
            for cls in (WK3l15k, WK3l120k, CN3l):
                ds = cls(graph_pair=graph_pair, side=side)
                ds.summarize()


if __name__ == "__main__":
    _main()