"""Clinical Knowledge Graph."""
import pathlib
import tarfile
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from urllib.request import urlretrieve
import click
import pandas as pd
from docdata import parse_docdata
from more_click import verbose_option
from .base import TabbedDataset
from ..typing import TorchRandomHint
# Public API of this module: only the dataset class is exported.
__all__ = [
"CKG",
]
# Location of the compressed archive that is downloaded on first use.
URL = "https://md-datasets-public-files-prod.s3.eu-west-1.amazonaws.com/d1e8d3df-2342-468a-91a9-97a981a479ad"
# Columns read from every TSV file in the archive, in the order in which they are kept.
COLUMNS = ["START_ID", "TYPE", "END_ID"]
# [docs]  (Sphinx "view source" link marker left over from extraction; not part of the code)
@parse_docdata
class CKG(TabbedDataset):
    """The Clinical Knowledge Graph (CKG) dataset from [santos2020]_.

    ---
    name: Clinical Knowledge Graph
    citation:
        author: Santos
        year: 2020
        link: https://doi.org/10.1101/2020.05.09.084897
        github: MannLabs/CKG
    single: true
    statistics:
        entities: 7617419
        relations: 11
        triples: 26691525
        training: 21353220
        testing: 2669152
        validation: 2669153
    """

    def __init__(
        self,
        random_state: TorchRandomHint = 0,
        **kwargs,
    ):
        """Initialize the `CKG <https://github.com/MannLabs/CKG>`_ dataset from [santos2020]_.

        :param random_state: The random seed to use in splitting the dataset. Defaults to 0.
        :param kwargs: keyword arguments passed to :class:`pykeen.datasets.base.TabbedDataset`.
        """
        super().__init__(
            random_state=random_state,
            **kwargs,
        )
        # NOTE(review): this attribute only exists once the base initializer has
        # returned. If the base class can trigger loading from inside its
        # ``__init__`` (e.g. via an eager option passed through ``kwargs``),
        # ``_get_df`` would run before ``preloaded_path`` is set -- confirm.
        self.preloaded_path = self.cache_root.joinpath("preloaded.tsv.gz")

    def _get_path(self) -> Optional[pathlib.Path]:
        """Return the path of the merged triples file inside the cache directory."""
        return self.preloaded_path

    def _get_df(self) -> pd.DataFrame:
        """Return all triples as one data frame with the columns listed in ``COLUMNS``.

        If the merged file at ``preloaded_path`` already exists it is read back
        directly. Otherwise the per-file frames from :meth:`_iterate_dataframes`
        are concatenated, written to ``preloaded_path`` and returned.

        :returns: a data frame of string-typed triples.
        """
        if self.preloaded_path.exists():
            return pd.read_csv(self.preloaded_path, sep="\t", dtype=str)

        # ``ignore_index=True`` gives the freshly built frame a clean 0..n-1
        # index, matching what a later run gets when reading the cached file
        # (which is written without an index column).
        df = pd.concat(self._iterate_dataframes(), ignore_index=True)
        df.to_csv(self.preloaded_path, sep="\t", index=False)
        return df

    def _iterate_dataframes(self) -> Iterable[pd.DataFrame]:
        """Yield one data frame per relevant TSV file of the source archive.

        The archive is downloaded from ``URL`` into the cache directory if it
        is not there yet. Only members below ``data/imports/`` that end in
        ``.tsv`` and whose file name does not start with a dot are read.

        :yields: a string-typed frame restricted to, and ordered by, ``COLUMNS``.
        :raises ValueError: if a selected archive member cannot be opened as a file.
        """
        archive_path = self.cache_root / "data.tar.gz"
        if not archive_path.exists():
            # Download next to the final location and move the file into place
            # only once the transfer finished, so an interrupted download is
            # never mistaken for a complete archive on the next run.
            partial_path = archive_path.with_name(archive_path.name + ".part")
            urlretrieve(URL, partial_path)  # noqa:S310
            partial_path.replace(archive_path)

        # ``tarfile.open`` either returns an archive object or raises, so no
        # ``None`` check is needed on the result.
        with tarfile.open(archive_path) as tar_file:
            for tarinfo in tar_file:
                member_name = tarinfo.name
                if not (member_name.startswith("data/imports/") and member_name.endswith(".tsv")):
                    continue
                # Skip hidden files (file name starting with a dot).
                if Path(member_name).name.startswith("."):
                    continue
                _inner_file = tar_file.extractfile(tarinfo)
                if _inner_file is None:
                    raise ValueError(f"Unable to open inner file: {tarinfo}")
                with _inner_file as file:
                    df = pd.read_csv(file, usecols=COLUMNS, sep="\t", dtype=str)
                # ``usecols`` keeps the file's own column order; re-select to
                # enforce the order given by ``COLUMNS``.
                yield df[COLUMNS]
@click.command()
@verbose_option
def _main():
    # Command-line entry point: load the CKG dataset and print its summary.
    # Documented with comments rather than a docstring, because click would
    # display a docstring as the command's help text.
    from pykeen.datasets import get_dataset

    get_dataset(dataset=CKG).summarize()


# Run the command only when this module is executed as a script.
if __name__ == "__main__":
    _main()