Source code for pykeen.datasets.ckg

# -*- coding: utf-8 -*-

"""Clinical Knowledge Graph."""

import tarfile
from pathlib import Path
from typing import Iterable, Optional
from urllib.request import urlretrieve

import click
import pandas as pd

from .base import TabbedDataset
from ..typing import TorchRandomHint

__all__ = [
    'CKG',
]

URL = 'https://md-datasets-public-files-prod.s3.eu-west-1.amazonaws.com/d1e8d3df-2342-468a-91a9-97a981a479ad'
COLUMNS = ['START_ID', 'TYPE', 'END_ID']


[docs]class CKG(TabbedDataset): """The Clinical Knowledge Graph (CKG) dataset from [santos2020]_. This dataset contains ~7.6 million nodes, 11 relations, and ~26 million triples. .. [santos2020] Santos, A., *et al* (2020). `Clinical Knowledge Graph Integrates Proteomics Data into Clinical Decision-Making <https://doi.org/10.1101/2020.05.09.084897>`_. *bioRxiv*, 2020.05.09.084897. """ def __init__( self, create_inverse_triples: bool = False, random_state: TorchRandomHint = 0, **kwargs, ): """Initialize the `CKG <https://github.com/MannLabs/CKG>`_ dataset from [santos2020]_. :param create_inverse_triples: Should inverse triples be created? Defaults to false. :param random_state: The random seed to use in splitting the dataset. Defaults to 0. :param kwargs: keyword arguments passed to :class:`pykeen.datasets.base.TabbedDataset`. """ super().__init__( create_inverse_triples=create_inverse_triples, random_state=random_state, **kwargs, ) self.preloaded_path = self.cache_root / 'preloaded.tsv.gz' def _get_path(self) -> Optional[str]: return self.preloaded_path.as_posix() def _get_df(self) -> pd.DataFrame: if self.preloaded_path.exists(): return pd.read_csv(self.preloaded_path, sep='\t') df = pd.concat(self._iterate_dataframes()) df.to_csv(self.preloaded_path, sep='\t', index=False) return df def _iterate_dataframes(self) -> Iterable[pd.DataFrame]: archive_path = self.cache_root / 'data.tar.gz' if not archive_path.exists(): urlretrieve(URL, archive_path) # noqa:S310 with tarfile.TarFile.open(archive_path) as tar_file: if tar_file is None: raise ValueError for tarinfo in tar_file: if not tarinfo.name.startswith('data/imports/') or not tarinfo.name.endswith('.tsv'): continue path = Path(tarinfo.name) if path.name.startswith('.'): continue _inner_file = tar_file.extractfile(tarinfo) if _inner_file is None: raise ValueError(f'Unable to open inner file: {tarinfo}') with _inner_file as file: df = pd.read_csv(file, usecols=COLUMNS, sep='\t', dtype=str) df = df[COLUMNS] yield df
@click.command() def _main(): d = CKG() d.summarize() if __name__ == '__main__': _main()