Source code for pykeen.datasets.ckg

# -*- coding: utf-8 -*-

"""Clinical Knowledge Graph."""

import pathlib
import tarfile
from pathlib import Path
from typing import Iterable, Optional
from urllib.request import urlretrieve

import click
import pandas as pd
from docdata import parse_docdata
from more_click import verbose_option

from .base import TabbedDataset
from ..typing import TorchRandomHint

# Public API of this module: only the dataset class is exported.
__all__ = [
    "CKG",
]

# Remote location of the data archive; CKG downloads it into its cache
# directory as "data.tar.gz" when that file is not already present.
URL = "https://md-datasets-public-files-prod.s3.eu-west-1.amazonaws.com/d1e8d3df-2342-468a-91a9-97a981a479ad"
# Columns read from every TSV file inside the archive, in the order in which
# they are kept in the resulting data frame.
COLUMNS = ["START_ID", "TYPE", "END_ID"]


@parse_docdata
class CKG(TabbedDataset):
    """The Clinical Knowledge Graph (CKG) dataset from [santos2020]_.

    ---
    name: Clinical Knowledge Graph
    citation:
        author: Santos
        year: 2020
        link: https://doi.org/10.1101/2020.05.09.084897
        github: MannLabs/CKG
    single: true
    statistics:
        entities: 7617419
        relations: 11
        triples: 26691525
        training: 21353220
        testing: 2669152
        validation: 2669153
    """

    def __init__(
        self,
        random_state: TorchRandomHint = 0,
        **kwargs,
    ):
        """Initialize the `CKG <https://github.com/MannLabs/CKG>`_ dataset from [santos2020]_.

        :param random_state: The random seed to use in splitting the dataset. Defaults to 0.
        :param kwargs: keyword arguments passed to :class:`pykeen.datasets.base.TabbedDataset`.
        """
        super().__init__(
            random_state=random_state,
            **kwargs,
        )
        # NOTE(review): this attribute only exists after the base initializer
        # returns; if the base class can load data during __init__, it would
        # not be set yet -- confirm against TabbedDataset.
        self.preloaded_path = self.cache_root.joinpath("preloaded.tsv.gz")

    def _get_path(self) -> Optional[pathlib.Path]:
        """Return the path of the cached, preprocessed triples file."""
        return self.preloaded_path

    def _get_df(self) -> pd.DataFrame:
        """Load the triples as a data frame with the columns in ``COLUMNS``.

        If the preprocessed cache file exists it is read directly. Otherwise
        the frames extracted from the archive are concatenated, written to the
        cache file and returned.

        :returns: A data frame of string-typed triples.
        """
        if self.preloaded_path.exists():
            return pd.read_csv(self.preloaded_path, sep="\t", dtype=str)

        # ignore_index gives the freshly built frame the same clean range
        # index that a later reload from the cache file has.
        df = pd.concat(self._iterate_dataframes(), ignore_index=True)

        # Write to a sibling file first and rename on success, so that an
        # interrupted write never leaves a truncated cache file behind that
        # later runs would load as if it were complete.
        partial_path = self.preloaded_path.with_name(self.preloaded_path.name + ".part")
        try:
            # The temporary name does not end in ".gz", so the compression
            # can no longer be inferred from it and is given explicitly.
            df.to_csv(partial_path, sep="\t", index=False, compression="gzip")
            partial_path.replace(self.preloaded_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()
        return df

    def _iterate_dataframes(self) -> Iterable[pd.DataFrame]:
        """Yield one data frame per relevant TSV file inside the archive.

        The archive is downloaded from ``URL`` into the cache directory when
        it is not already there. Only non-hidden ``.tsv`` members below
        ``data/imports/`` are read.

        :yields: Data frames restricted to, and ordered by, ``COLUMNS``.
        :raises ValueError: If a selected archive member cannot be opened.
        """
        archive_path = self.cache_root / "data.tar.gz"
        if not archive_path.exists():
            # Download to a temporary name and rename on success, so that a
            # failed download is not mistaken for a complete archive later.
            partial_path = archive_path.with_name(archive_path.name + ".part")
            try:
                urlretrieve(URL, partial_path)  # noqa:S310
                partial_path.replace(archive_path)
            finally:
                if partial_path.exists():
                    partial_path.unlink()

        with tarfile.open(archive_path) as tar_file:
            for tarinfo in tar_file:
                member_name = tarinfo.name
                if not member_name.startswith("data/imports/") or not member_name.endswith(".tsv"):
                    continue
                # Skip hidden files, i.e. those whose base name starts with a dot.
                if Path(member_name).name.startswith("."):
                    continue
                inner_file = tar_file.extractfile(tarinfo)
                if inner_file is None:
                    raise ValueError(f"Unable to open inner file: {tarinfo}")
                with inner_file as file:
                    df = pd.read_csv(file, usecols=COLUMNS, sep="\t", dtype=str)
                    # usecols selects columns but keeps the file's own order;
                    # re-index so every frame has the same column order.
                    yield df[COLUMNS]
# Command line entry point: obtains the CKG dataset through pykeen's dataset
# resolver and calls its ``summarize()`` method.
# Documented with comments rather than a docstring so that the help text
# click generates for this command stays unchanged.
@click.command()
@verbose_option
def _main():
    # Imported here rather than at module level, as in the original code.
    from pykeen.datasets import get_dataset

    ckg_dataset = get_dataset(dataset=CKG)
    ckg_dataset.summarize()


if __name__ == "__main__":
    _main()