# Source code for pykeen.hpo.hpo

# -*- coding: utf-8 -*-

"""Hyper-parameter optimization in PyKEEN."""

import dataclasses
import ftplib
import inspect
import json
import logging
import os
import pathlib
from dataclasses import dataclass
from typing import Any, Callable, Collection, Dict, Iterable, Mapping, Optional, Type, Union, cast

import torch
from class_resolver.contrib.optuna import pruner_resolver, sampler_resolver
from optuna import Study, Trial, TrialPruned, create_study
from optuna.pruners import BasePruner
from optuna.samplers import BaseSampler
from optuna.storages import BaseStorage

from ..constants import USER_DEFINED_CODE
from ..datasets import dataset_resolver, has_dataset
from ..datasets.base import Dataset
from ..evaluation import Evaluator, evaluator_resolver
from ..losses import Loss, loss_resolver
from ..lr_schedulers import LRScheduler, lr_scheduler_resolver, lr_schedulers_hpo_defaults
from ..models import Model, model_resolver
from ..optimizers import Optimizer, optimizer_resolver, optimizers_hpo_defaults
from ..pipeline import pipeline, replicate_pipeline_from_config
from ..regularizers import Regularizer, regularizer_resolver
from ..sampling import NegativeSampler, negative_sampler_resolver
from ..stoppers import EarlyStopper, Stopper, stopper_resolver
from ..trackers import ResultTracker, tracker_resolver
from ..training import SLCWATrainingLoop, TrainingLoop, training_loop_resolver
from ..triples import CoreTriplesFactory
from ..typing import Hint, HintType
from ..utils import Result, ensure_ftp_directory, fix_dataclass_init_docs, get_df_io, get_json_bytes_io, normalize_path
from ..version import get_git_hash, get_version

__all__ = [
    "hpo_pipeline_from_path",
    "hpo_pipeline_from_config",
    "hpo_pipeline",
    "HpoPipelineResult",
]

logger = logging.getLogger(__name__)

# Name of the trial user attribute under which the epoch at which early stopping halted is stored
STOPPED_EPOCH_KEY = "stopped_epoch"

class ExtraKeysError(ValueError):
    """Raised on extra keys being used."""

    def __init__(self, keys: Iterable[str]):
        """
        Initialize the error.

        :param keys:
            the extra keys
        """
        # store a sorted list so that the message is deterministic even when ``keys`` is a set
        super().__init__(sorted(keys))

    def __str__(self) -> str:
        """Return a human-readable message listing the invalid keys."""
        return f"Invalid keys: {self.args[0]}"

@fix_dataclass_init_docs
@dataclass
class Objective:
    """A dataclass containing all of the information to make an objective function."""

    dataset: Union[None, str, Dataset, Type[Dataset]]  # 1.
    model: Type[Model]  # 2.
    loss: Type[Loss]  # 3.
    optimizer: Type[Optimizer]  # 5.
    training_loop: Type[TrainingLoop]  # 6.
    stopper: Type[Stopper]  # 7.
    evaluator: Type[Evaluator]  # 8.
    result_tracker: Union[ResultTracker, Type[ResultTracker]]  # 9.
    metric: str

    # 1. Dataset
    dataset_kwargs: Optional[Mapping[str, Any]] = None
    training: Hint[CoreTriplesFactory] = None
    testing: Hint[CoreTriplesFactory] = None
    validation: Hint[CoreTriplesFactory] = None
    evaluation_entity_whitelist: Optional[Collection[str]] = None
    evaluation_relation_whitelist: Optional[Collection[str]] = None
    # 2. Model
    model_kwargs: Optional[Mapping[str, Any]] = None
    model_kwargs_ranges: Optional[Mapping[str, Any]] = None
    # 3. Loss
    loss_kwargs: Optional[Mapping[str, Any]] = None
    loss_kwargs_ranges: Optional[Mapping[str, Any]] = None
    # 4. Regularizer
    regularizer: Optional[Type[Regularizer]] = None
    regularizer_kwargs: Optional[Mapping[str, Any]] = None
    regularizer_kwargs_ranges: Optional[Mapping[str, Any]] = None
    # 5. Optimizer
    optimizer_kwargs: Optional[Mapping[str, Any]] = None
    optimizer_kwargs_ranges: Optional[Mapping[str, Any]] = None
    # 5.1 Learning Rate Scheduler
    lr_scheduler: Optional[Type[LRScheduler]] = None
    lr_scheduler_kwargs: Optional[Mapping[str, Any]] = None
    lr_scheduler_kwargs_ranges: Optional[Mapping[str, Any]] = None
    # 6. Training Loop
    training_loop_kwargs: Optional[Mapping[str, Any]] = None
    negative_sampler: Optional[Type[NegativeSampler]] = None
    negative_sampler_kwargs: Optional[Mapping[str, Any]] = None
    negative_sampler_kwargs_ranges: Optional[Mapping[str, Any]] = None
    # 7. Training
    training_kwargs: Optional[Mapping[str, Any]] = None
    training_kwargs_ranges: Optional[Mapping[str, Any]] = None
    stopper_kwargs: Optional[Mapping[str, Any]] = None
    # 8. Evaluation
    evaluator_kwargs: Optional[Mapping[str, Any]] = None
    evaluation_kwargs: Optional[Mapping[str, Any]] = None
    filter_validation_when_testing: bool = True
    # 9. Trackers
    result_tracker_kwargs: Optional[Mapping[str, Any]] = None
    # Misc.
    device: Union[None, str, torch.device] = None
    save_model_directory: Optional[str] = None

    @staticmethod
    def _update_stopper_callbacks(
        stopper_kwargs: Dict[str, Any],
        trial: Trial,
        metric: str,
        result_tracker: ResultTracker,
    ) -> None:
        """Register early-stopper callbacks which report to the optuna trial.

        :param stopper_kwargs:
            the keyword arguments for the stopper. This dictionary is modified in place: the callbacks are added
            to (fresh copies of) its ``result_callbacks`` and ``stopped_callbacks`` lists.
        :param trial:
            the optuna trial to report intermediate results to
        :param metric:
            the name of the optimized metric, used for logging
        :param result_tracker:
            the result tracker of this trial
        """

        def _result_callback(_early_stopper: EarlyStopper, result: Union[float, int], epoch: int) -> None:
            # report the intermediate result, so that the pruner can decide whether to stop this trial
            trial.report(result, step=epoch)
            if trial.should_prune():
                # log pruning
                result_tracker.log_metrics(metrics=dict(pruned=1), step=epoch)
                # trial was successful, but has to be ended
                result_tracker.end_run(success=True)
                # also show info
                logger.info(f"Pruned trial: {trial} at epoch {epoch} due to {metric}={result}")
                raise TrialPruned()

        def _stopped_callback(_early_stopper: EarlyStopper, _result: Union[float, int], epoch: int) -> None:
            # remember the epoch at which early stopping halted, so the best config can reuse it
            trial.set_user_attr(STOPPED_EPOCH_KEY, epoch)

        for key, callback in zip(("result_callbacks", "stopped_callbacks"), (_result_callback, _stopped_callback)):
            # copy the list: appending in place would mutate a user-provided list shared by all trials,
            # so callbacks bound to earlier (finished) trials would keep accumulating
            stopper_kwargs[key] = [*(stopper_kwargs.get(key) or []), callback]

    def __call__(self, trial: Trial) -> Optional[float]:
        """Suggest parameters then train the model.

        :param trial:
            the optuna trial for which hyper-parameters are suggested
        :return:
            the value of the optimized metric on the validation set
        :raises ValueError:
            if ``model_kwargs`` contains keys belonging to other components, if no negative sampler is given
            when training under sLCWA, or if a fixed checkpoint name is set
        :raises ExtraKeysError:
            if the negative sampler's hyper-parameter ranges contain keys the sampler does not accept
        """
        if self.model_kwargs is not None:
            problems = [
                x
                for x in ("loss", "regularizer", "optimizer", "lr_scheduler", "training", "negative_sampler", "stopper")
                if x in self.model_kwargs
            ]
            if problems:
                raise ValueError(f"model_kwargs should not have: {problems}. {self}")

        # 2. Model
        _model_kwargs = _get_kwargs(
            trial=trial,
            prefix="model",
            default_kwargs_ranges=self.model.hpo_default,
            kwargs=self.model_kwargs,
            kwargs_ranges=self.model_kwargs_ranges,
        )

        try:
            loss_default_kwargs_ranges = self.loss.hpo_default
        except AttributeError:
            logger.warning("using a loss function with no hpo_default field: %s", self.loss)
            loss_default_kwargs_ranges = {}

        # 3. Loss
        _loss_kwargs = _get_kwargs(
            trial=trial,
            prefix="loss",
            default_kwargs_ranges=loss_default_kwargs_ranges,
            kwargs=self.loss_kwargs,
            kwargs_ranges=self.loss_kwargs_ranges,
        )
        # 4. Regularizer
        _regularizer_kwargs: Optional[Mapping[str, Any]]
        if self.regularizer is None:
            _regularizer_kwargs = {}
        else:
            _regularizer_kwargs = _get_kwargs(
                trial=trial,
                prefix="regularizer",
                default_kwargs_ranges=self.regularizer.hpo_default,
                kwargs=self.regularizer_kwargs,
                kwargs_ranges=self.regularizer_kwargs_ranges,
            )
        # 5. Optimizer
        _optimizer_kwargs = _get_kwargs(
            trial=trial,
            prefix="optimizer",
            default_kwargs_ranges=optimizers_hpo_defaults[self.optimizer],
            kwargs=self.optimizer_kwargs,
            kwargs_ranges=self.optimizer_kwargs_ranges,
        )
        # 5.1 Learning Rate Scheduler
        _lr_scheduler_kwargs: Optional[Mapping[str, Any]] = None
        if self.lr_scheduler is not None:
            _lr_scheduler_kwargs = _get_kwargs(
                trial=trial,
                prefix="lr_scheduler",
                default_kwargs_ranges=lr_schedulers_hpo_defaults[self.lr_scheduler],
                kwargs=self.lr_scheduler_kwargs,
                kwargs_ranges=self.lr_scheduler_kwargs_ranges,
            )

        _negative_sampler_kwargs: Mapping[str, Any]
        if self.training_loop is not SLCWATrainingLoop:
            _negative_sampler_kwargs = {}
        elif self.negative_sampler is None:
            raise ValueError("Negative sampler class must be made explicit when training under sLCWA")
        else:
            # TODO this fixes the issue for negative samplers, but does not generally address it.
            #  For example, some of them obscure their arguments with **kwargs, so should we look
            #  at the parent class? Sounds like something to put in class resolver by using the
            #  inspect module. For now, this solution will rely on the fact that the sampler is a
            #  direct descendent of a parent NegativeSampler
            direct_params = inspect.signature(self.negative_sampler).parameters
            parent_params = inspect.signature(self.negative_sampler.__bases__[0]).parameters
            valid_keys = set(direct_params).union(parent_params) - {"kwargs"}
            invalid_keys = set(self.negative_sampler_kwargs_ranges or []) - valid_keys
            if invalid_keys:
                raise ExtraKeysError(invalid_keys)
            _negative_sampler_kwargs = _get_kwargs(
                trial=trial,
                prefix="negative_sampler",
                default_kwargs_ranges=self.negative_sampler.hpo_default,
                kwargs=self.negative_sampler_kwargs,
                kwargs_ranges=self.negative_sampler_kwargs_ranges,
            )

        # 7. Training
        _training_kwargs = _get_kwargs(
            trial=trial,
            prefix="training",
            default_kwargs_ranges=self.training_loop.hpo_default,
            kwargs=self.training_kwargs,
            kwargs_ranges=self.training_kwargs_ranges,
        )

        # a fixed checkpoint_name would be shared by all trials, so their checkpoints would collide
        checkpoint_name = _training_kwargs.get("checkpoint_name", None)
        if checkpoint_name:
            raise ValueError(
                f"Cannot set a fixed {checkpoint_name=} across all trials; if you want to save the final model per "
                f"trial, use `save_model_directory` instead!",
            )

        # create result tracker to allow to gracefully close failed trials
        result_tracker = tracker_resolver.make(query=self.result_tracker, pos_kwargs=self.result_tracker_kwargs)

        _stopper_kwargs = dict(self.stopper_kwargs or {})
        if self.stopper is not None and issubclass(self.stopper, EarlyStopper):
            self._update_stopper_callbacks(_stopper_kwargs, trial, metric=self.metric, result_tracker=result_tracker)

        try:
            result = pipeline(
                # 1. Dataset
                dataset=self.dataset,
                dataset_kwargs=self.dataset_kwargs,
                training=self.training,
                testing=self.testing,
                validation=self.validation,
                evaluation_entity_whitelist=self.evaluation_entity_whitelist,
                evaluation_relation_whitelist=self.evaluation_relation_whitelist,
                # 2. Model
                model=self.model,
                model_kwargs=_model_kwargs,
                # 3. Loss
                loss=self.loss,
                loss_kwargs=_loss_kwargs,
                # 4. Regularizer
                regularizer=self.regularizer,
                regularizer_kwargs=_regularizer_kwargs,
                # 5. Optimizer
                optimizer=self.optimizer,
                optimizer_kwargs=_optimizer_kwargs,
                clear_optimizer=True,
                # 5.1 Learning Rate Scheduler
                lr_scheduler=self.lr_scheduler,
                lr_scheduler_kwargs=_lr_scheduler_kwargs,
                # 6. Training Loop
                training_loop=self.training_loop,
                training_loop_kwargs=self.training_loop_kwargs,
                negative_sampler=self.negative_sampler,
                negative_sampler_kwargs=_negative_sampler_kwargs,
                # 7. Training
                training_kwargs=_training_kwargs,
                stopper=self.stopper,
                stopper_kwargs=_stopper_kwargs,
                # 8. Evaluation
                evaluator=self.evaluator,
                evaluator_kwargs=self.evaluator_kwargs,
                evaluation_kwargs=self.evaluation_kwargs,
                filter_validation_when_testing=self.filter_validation_when_testing,
                # 9. Tracker
                result_tracker=result_tracker,
                # Misc.
                use_testing_data=False,  # use validation set during HPO!
                device=self.device,
            )
        except (MemoryError, RuntimeError):
            # close run in result tracker
            result_tracker.end_run(success=False)
            # raise the error again (which will be caught in study.optimize)
            raise
        else:
            if self.save_model_directory:
                model_directory = os.path.join(self.save_model_directory, str(trial.number))
                os.makedirs(model_directory, exist_ok=True)
                result.save_to_directory(model_directory)

            trial.set_user_attr("random_seed", result.random_seed)

            for k, v in result.metric_results.to_flat_dict().items():
                trial.set_user_attr(k, v)

            return result.metric_results.get_metric(self.metric)

@fix_dataclass_init_docs
@dataclass
class HpoPipelineResult(Result):
    """A container for the results of the HPO pipeline."""

    #: The :mod:`optuna` study object
    study: Study
    #: The objective class, containing information on preset hyper-parameters and those to optimize
    objective: Objective

    def _get_best_study_config(self):
        """Build the configuration of the best trial.

        :return:
            a dictionary with the keys ``"metadata"`` (provenance and comments) and ``"pipeline"`` (keyword
            arguments suitable for re-running the pipeline with the best hyper-parameters)
        """
        metadata = {
            "best_trial_number": self.study.best_trial.number,
            "best_trial_evaluation": self.study.best_value,
        }

        pipeline_config = dict()
        for k, v in self.study.user_attrs.items():
            if k.startswith("pykeen_"):
                metadata[k[len("pykeen_") :]] = v
            elif k in {"metric"}:
                continue
            else:
                pipeline_config[k] = v

        for field in dataclasses.fields(self.objective):
            field_value = getattr(self.objective, field.name)
            if not field_value:
                continue
            if field.name.endswith("_kwargs"):
                logger.debug(f"saving pre-specified field in pipeline config: {field.name}={field_value}")
                # copy, since optimized parameters / the stopped epoch are written into this dictionary below;
                # storing the objective's own mapping would silently modify the objective
                pipeline_config[field.name] = dict(field_value)
            elif field.name == "result_tracker":
                # the objective may hold either a tracker class or an already instantiated tracker
                tracker_cls = field_value if isinstance(field_value, type) else type(field_value)
                if issubclass(tracker_cls, ResultTracker):
                    tracker_subclass = tracker_resolver.normalize_cls(tracker_cls)
                    if not tracker_subclass:  # field_value is base class
                        continue
                    pipeline_config[field.name] = tracker_subclass
                else:
                    logger.error(f"Invalid value for field {field.name}: {field_value!r}")
            elif field.name in {"training", "testing", "validation"}:
                pipeline_config[field.name] = field_value if isinstance(field_value, str) else USER_DEFINED_CODE

        # optimized parameters are named "<component>.<parameter>", cf. suggest_kwargs
        for k, v in self.study.best_params.items():
            sk, ssk = k.split(".")
            sk = f"{sk}_kwargs"
            if sk not in pipeline_config:
                pipeline_config[sk] = {}
            logger.debug(f"saving optimized field in pipeline config: {sk}.{ssk}={v}")
            pipeline_config[sk][ssk] = v

        for k in ("stopper", "stopper_kwargs"):
            if k in pipeline_config:
                v = pipeline_config.pop(k)
                metadata[f"_{k}_removed_comment"] = f"{k} config removed after HPO: {v}"

        stopped_epoch = self.study.best_trial.user_attrs.get(STOPPED_EPOCH_KEY)
        if stopped_epoch is not None:
            training_kwargs = pipeline_config.setdefault("training_kwargs", {})
            old_num_epochs = training_kwargs.get("num_epochs")
            metadata["_stopper_comment"] = (
                f"While the original config had {old_num_epochs},"
                f" early stopping will now switch it to {int(stopped_epoch)}"
            )
            training_kwargs["num_epochs"] = int(stopped_epoch)
        return dict(metadata=metadata, pipeline=pipeline_config)

    def save_to_directory(self, directory: Union[str, pathlib.Path], **kwargs) -> None:
        """Dump the results of a study to the given directory.

        Writes ``study.json`` (the study's user attributes), ``trials.tsv`` (all trials) and
        ``best_pipeline/pipeline_config.json`` (the best configuration).

        :param directory: the output directory; it is created if missing
        :param kwargs: ignored
        """
        directory = normalize_path(directory, mkdir=True)

        # Output study information
        with directory.joinpath("study.json").open("w") as file:
            json.dump(self.study.user_attrs, file, indent=2, sort_keys=True)

        # Output all trials
        df = self.study.trials_dataframe()
        df.to_csv(directory.joinpath("trials.tsv"), sep="\t", index=False)

        best_pipeline_directory = directory.joinpath("best_pipeline")
        best_pipeline_directory.mkdir(exist_ok=True, parents=True)
        # Output best trial as pipeline configuration file
        with best_pipeline_directory.joinpath("pipeline_config.json").open("w") as file:
            json.dump(self._get_best_study_config(), file, indent=2, sort_keys=True)

    def save_to_ftp(self, directory: str, ftp: ftplib.FTP):
        """Save the results to the directory in an FTP server.

        :param directory: The directory in the FTP server to save to
        :param ftp: A connection to the FTP server
        """
        ensure_ftp_directory(ftp=ftp, directory=directory)

        study_path = os.path.join(directory, "study.json")
        ftp.storbinary(f"STOR {study_path}", get_json_bytes_io(self.study.user_attrs))

        trials_path = os.path.join(directory, "trials.tsv")
        ftp.storbinary(f"STOR {trials_path}", get_df_io(self.study.trials_dataframe()))

        best_pipeline_directory = os.path.join(directory, "best_pipeline")
        ensure_ftp_directory(ftp=ftp, directory=best_pipeline_directory)

        best_config_path = os.path.join(best_pipeline_directory, "pipeline_config.json")
        ftp.storbinary(f"STOR {best_config_path}", get_json_bytes_io(self._get_best_study_config()))

    def save_to_s3(self, directory: str, bucket: str, s3=None) -> None:
        """Save all artifacts to the given directory in an S3 Bucket.

        :param directory: The directory in the S3 bucket
        :param bucket: The name of the S3 bucket
        :param s3: A client from :func:`boto3.client`, if already instantiated
        """
        if s3 is None:
            import boto3

            s3 = boto3.client("s3")

        study_path = os.path.join(directory, "study.json")
        s3.upload_fileobj(get_json_bytes_io(self.study.user_attrs), bucket, study_path)

        trials_path = os.path.join(directory, "trials.tsv")
        s3.upload_fileobj(get_df_io(self.study.trials_dataframe()), bucket, trials_path)

        best_config_path = os.path.join(directory, "best_pipeline", "pipeline_config.json")
        s3.upload_fileobj(get_json_bytes_io(self._get_best_study_config()), bucket, best_config_path)

    def replicate_best_pipeline(
        self,
        *,
        directory: Union[str, pathlib.Path],
        replicates: int,
        move_to_cpu: bool = False,
        save_replicates: bool = True,
        save_training: bool = False,
    ) -> None:
        """Run the pipeline on the best configuration, but this time on the "test" set instead of "evaluation" set.

        :param directory: Output directory
        :param replicates: The number of times to retrain the model
        :param move_to_cpu: Should the model be moved back to the CPU? Only relevant if training on GPU.
        :param save_replicates: Should the artifacts of the replicates be saved?
        :param save_training: Should the training triples be saved?
        :raises ValueError: if :data:`"use_testing_data"` is provided in the best pipeline's `config`.
        """
        config = self._get_best_study_config()

        # the pipeline arguments live under the "pipeline" key; also guard the top level for safety
        if "use_testing_data" in config or "use_testing_data" in config["pipeline"]:
            raise ValueError("use_testing_data should not be set in the configuration at all!")

        replicate_pipeline_from_config(
            config=config,
            directory=directory,
            replicates=replicates,
            use_testing_data=True,
            move_to_cpu=move_to_cpu,
            save_replicates=save_replicates,
            save_training=save_training,
        )


def hpo_pipeline_from_path(path: Union[str, pathlib.Path], **kwargs) -> HpoPipelineResult:
    """Run a HPO study from the configuration at the given path.

    :param path: the path to a JSON configuration file
    :param kwargs: additional keyword-based parameters passed to :func:`hpo_pipeline_from_config`
    :return: the optimization result
    """
    # JSON is UTF-8 by specification; do not depend on the platform's default encoding
    with open(path, encoding="utf-8") as file:
        config = json.load(file)
    return hpo_pipeline_from_config(config, **kwargs)


def hpo_pipeline_from_config(config: Mapping[str, Any], **kwargs) -> HpoPipelineResult:
    """Run the HPO pipeline using a properly formatted configuration dictionary.

    :param config:
        a mapping with the keys ``"pipeline"`` and ``"optuna"``, whose values are keyword arguments for
        :func:`hpo_pipeline`
    :param kwargs:
        additional keyword-based parameters for :func:`hpo_pipeline`; they must not repeat a key of the
        configuration, otherwise a :class:`TypeError` is raised
    :return: the optimization result
    """
    return hpo_pipeline(
        **config["pipeline"],
        **config["optuna"],
        **kwargs,
    )


def hpo_pipeline(
    *,
    # 1. Dataset
    dataset: Union[None, str, Dataset, Type[Dataset]] = None,
    dataset_kwargs: Optional[Mapping[str, Any]] = None,
    training: Hint[CoreTriplesFactory] = None,
    testing: Hint[CoreTriplesFactory] = None,
    validation: Hint[CoreTriplesFactory] = None,
    evaluation_entity_whitelist: Optional[Collection[str]] = None,
    evaluation_relation_whitelist: Optional[Collection[str]] = None,
    # 2. Model
    model: Union[str, Type[Model]],
    model_kwargs: Optional[Mapping[str, Any]] = None,
    model_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    # 3. Loss
    loss: HintType[Loss] = None,
    loss_kwargs: Optional[Mapping[str, Any]] = None,
    loss_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    # 4. Regularizer
    regularizer: HintType[Regularizer] = None,
    regularizer_kwargs: Optional[Mapping[str, Any]] = None,
    regularizer_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    # 5. Optimizer
    optimizer: HintType[Optimizer] = None,
    optimizer_kwargs: Optional[Mapping[str, Any]] = None,
    optimizer_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    # 5.1 Learning Rate Scheduler
    lr_scheduler: HintType[LRScheduler] = None,
    lr_scheduler_kwargs: Optional[Mapping[str, Any]] = None,
    lr_scheduler_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    # 6. Training Loop
    training_loop: HintType[TrainingLoop] = None,
    training_loop_kwargs: Optional[Mapping[str, Any]] = None,
    negative_sampler: HintType[NegativeSampler] = None,
    negative_sampler_kwargs: Optional[Mapping[str, Any]] = None,
    negative_sampler_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    # 7. Training
    epochs: Optional[int] = None,
    training_kwargs: Optional[Mapping[str, Any]] = None,
    training_kwargs_ranges: Optional[Mapping[str, Any]] = None,
    stopper: HintType[Stopper] = None,
    stopper_kwargs: Optional[Mapping[str, Any]] = None,
    # 8. Evaluation
    evaluator: HintType[Evaluator] = None,
    evaluator_kwargs: Optional[Mapping[str, Any]] = None,
    evaluation_kwargs: Optional[Mapping[str, Any]] = None,
    metric: Optional[str] = None,
    filter_validation_when_testing: bool = True,
    # 9. Tracking
    result_tracker: HintType[ResultTracker] = None,
    result_tracker_kwargs: Optional[Mapping[str, Any]] = None,
    # Misc.
    device: Hint[torch.device] = None,
    # Optuna Study Settings
    storage: Hint[BaseStorage] = None,
    sampler: HintType[BaseSampler] = None,
    sampler_kwargs: Optional[Mapping[str, Any]] = None,
    pruner: HintType[BasePruner] = None,
    pruner_kwargs: Optional[Mapping[str, Any]] = None,
    study_name: Optional[str] = None,
    direction: Optional[str] = None,
    load_if_exists: bool = False,
    # Optuna Optimization Settings
    n_trials: Optional[int] = None,
    timeout: Optional[int] = None,
    gc_after_trial: Optional[bool] = None,
    n_jobs: Optional[int] = None,
    save_model_directory: Optional[str] = None,
) -> HpoPipelineResult:
    """Train a model on the given dataset.

    :param dataset:
        The name of the dataset (a key for the :data:`pykeen.datasets.dataset_resolver`) or the
        :class:`pykeen.datasets.Dataset` instance. Alternatively, the training triples factory (``training``), testing
        triples factory (``testing``), and validation triples factory (``validation``; optional) can be specified.
    :param dataset_kwargs:
        The keyword arguments passed to the dataset upon instantiation
    :param training:
        A triples factory with training instances or path to the training file if a dataset was not specified
    :param testing:
        A triples factory with test instances or path to the test file if a dataset was not specified
    :param validation:
        A triples factory with validation instances or path to the validation file if a dataset was not specified
    :param evaluation_entity_whitelist:
        Optional restriction of evaluation to triples containing *only* these entities. Useful if the downstream task
        is only interested in certain entities, but the relational patterns with other entities improve the entity
        embedding quality. Passed to :func:`pykeen.pipeline.pipeline`.
    :param evaluation_relation_whitelist:
        Optional restriction of evaluation to triples containing *only* these relations. Useful if the downstream task
        is only interested in certain relation, but the relational patterns with other relations improve the entity
        embedding quality. Passed to :func:`pykeen.pipeline.pipeline`.

    :param model:
        The name of the model or the model class to pass to :func:`pykeen.pipeline.pipeline`
    :param model_kwargs:
        Keyword arguments to pass to the model class on instantiation
    :param model_kwargs_ranges:
        Strategies for optimizing the models' hyper-parameters to override
        the defaults

    :param loss:
        The name of the loss or the loss class to pass to :func:`pykeen.pipeline.pipeline`
    :param loss_kwargs:
        Keyword arguments to pass to the loss on instantiation
    :param loss_kwargs_ranges:
        Strategies for optimizing the losses' hyper-parameters to override
        the defaults

    :param regularizer:
        The name of the regularizer or the regularizer class to pass to :func:`pykeen.pipeline.pipeline`
    :param regularizer_kwargs:
        Keyword arguments to pass to the regularizer on instantiation
    :param regularizer_kwargs_ranges:
        Strategies for optimizing the regularizers' hyper-parameters to override
        the defaults

    :param optimizer:
        The name of the optimizer or the optimizer class. Defaults to :class:`torch.optim.Adagrad`.
    :param optimizer_kwargs:
        Keyword arguments to pass to the optimizer on instantiation
    :param optimizer_kwargs_ranges:
        Strategies for optimizing the optimizers' hyper-parameters to override
        the defaults

    :param lr_scheduler:
        The name of the lr_scheduler or the lr_scheduler class.
    :param lr_scheduler_kwargs:
        Keyword arguments to pass to the lr_scheduler on instantiation
    :param lr_scheduler_kwargs_ranges:
        Strategies for optimizing the lr_schedulers' hyper-parameters to override
        the defaults

    :param training_loop:
        The name of the training approach (``'slcwa'`` or ``'lcwa'``) or the training loop class
        to pass to :func:`pykeen.pipeline.pipeline`
    :param training_loop_kwargs:
        additional keyword-based parameters passed to the training loop upon instantiation.
    :param negative_sampler:
        The name of the negative sampler (``'basic'`` or ``'bernoulli'``) or the negative sampler class to pass to
        :func:`pykeen.pipeline.pipeline`. Only allowed when training with sLCWA.
    :param negative_sampler_kwargs:
        Keyword arguments to pass to the negative sampler class on instantiation
    :param negative_sampler_kwargs_ranges:
        Strategies for optimizing the negative samplers' hyper-parameters to override
        the defaults

    :param epochs:
        A shortcut for setting the ``num_epochs`` key in the ``training_kwargs`` dict.
    :param training_kwargs:
        Keyword arguments to pass to the training loop's train function on call
    :param training_kwargs_ranges:
        Strategies for optimizing the training loops' hyper-parameters to override
        the defaults. Can not specify ranges for the number of epochs if early stopping is enabled.
    :param stopper:
        What kind of stopping to use. Default to no stopping, can be set to 'early'.
    :param stopper_kwargs:
        Keyword arguments to pass to the stopper upon instantiation.

    :param evaluator:
        The name of the evaluator or an evaluator class. Defaults to :class:`pykeen.evaluation.RankBasedEvaluator`.
    :param evaluator_kwargs:
        Keyword arguments to pass to the evaluator on instantiation
    :param evaluation_kwargs:
        Keyword arguments to pass to the evaluator's evaluate function on call
    :param filter_validation_when_testing:
        If true, during evaluating on the test dataset, validation triples are added to the set of known positive
        triples, which are filtered out when performing filtered evaluation following the approach described by
        [bordes2013]_. Defaults to true.
    :param result_tracker:
        The ResultsTracker class or name
    :param result_tracker_kwargs:
        The keyword arguments passed to the results tracker on instantiation
    :param metric:
        The metric to optimize over. Defaults to mean reciprocal rank.
    :param save_model_directory:
        If given, the final model of each trial is saved under this directory.

    :param storage:
        the study's storage, cf. :func:`optuna.study.create_study`
    :param sampler:
        the sampler, or a hint thereof, cf. :func:`optuna.study.create_study`
    :param sampler_kwargs:
        additional keyword-based parameters for the sampler
    :param pruner:
        the pruner, or a hint thereof, cf. :func:`optuna.study.create_study`
    :param pruner_kwargs:
        additional keyword-based parameters for the pruner
    :param device:
        the device to use.
    :param study_name:
        the study's name, cf. :func:`optuna.study.create_study`
    :param direction:
        The direction of optimization. Because the default metric is mean reciprocal rank,
        the default direction is ``maximize``.
        cf. :func:`optuna.study.create_study`
    :param load_if_exists:
        whether to load the study if it already exists, cf. :func:`optuna.study.create_study`

    :param n_trials:
        the number of trials, cf. :meth:`optuna.study.Study.optimize`.
    :param timeout:
        the timeout, cf. :meth:`optuna.study.Study.optimize`.
    :param gc_after_trial:
        the garbage collection after trial, cf. :meth:`optuna.study.Study.optimize`.
    :param n_jobs:
        The number of parallel jobs, cf. :meth:`optuna.study.Study.optimize`. If this argument is set to :obj:`-1`,
        the number is set to CPU counts. If none, defaults to 1.

    :return:
        the optimization result

    :raises ValueError:
        if early stopping is enabled, but the number of epochs is to be optimized, too.
    """
    if direction is None:
        # TODO: use metric.increasing to determine default direction
        direction = "maximize"

    study = create_study(
        storage=storage,
        sampler=sampler_resolver.make(sampler, sampler_kwargs),
        pruner=pruner_resolver.make(pruner, pruner_kwargs),
        study_name=study_name,
        direction=direction,
        load_if_exists=load_if_exists,
    )

    # 0. Metadata/Provenance
    study.set_user_attr("pykeen_version", get_version())
    study.set_user_attr("pykeen_git_hash", get_git_hash())
    # 1. Dataset
    _set_study_dataset(
        study=study,
        dataset=dataset,
        training=training,
        testing=testing,
        validation=validation,
    )

    # 2. Model
    model_cls: Type[Model] = model_resolver.lookup(model)
    study.set_user_attr("model", model_resolver.normalize_cls(model_cls))
    logger.info(f"Using model: {model_cls}")
    # 3. Loss
    loss_cls: Type[Loss] = model_cls.loss_default if loss is None else loss_resolver.lookup(loss)
    study.set_user_attr("loss", loss_resolver.normalize_cls(loss_cls))
    logger.info(f"Using loss: {loss_cls}")
    # 4. Regularizer
    regularizer_cls: Optional[Type[Regularizer]]
    if regularizer is not None:
        regularizer_cls = regularizer_resolver.lookup(regularizer)
    elif getattr(model_cls, "regularizer_default", None):
        regularizer_cls = model_cls.regularizer_default  # type:ignore
    else:
        regularizer_cls = None
    if regularizer_cls:
        study.set_user_attr("regularizer", regularizer_cls.get_normalized_name())
        logger.info(f"Using regularizer: {regularizer_cls}")
    # 5. Optimizer
    optimizer_cls: Type[Optimizer] = optimizer_resolver.lookup(optimizer)
    study.set_user_attr("optimizer", optimizer_resolver.normalize_cls(optimizer_cls))
    logger.info(f"Using optimizer: {optimizer_cls}")
    # 5.1 Learning Rate Scheduler
    lr_scheduler_cls: Optional[Type[LRScheduler]] = None
    if lr_scheduler is not None:
        lr_scheduler_cls = lr_scheduler_resolver.lookup(lr_scheduler)
        study.set_user_attr("lr_scheduler", lr_scheduler_resolver.normalize_cls(lr_scheduler_cls))
        logger.info(f"Using lr_scheduler: {lr_scheduler_cls}")
    # 6. Training Loop
    training_loop_cls: Type[TrainingLoop] = training_loop_resolver.lookup(training_loop)
    study.set_user_attr("training_loop", training_loop_cls.get_normalized_name())
    logger.info(f"Using training loop: {training_loop_cls}")
    negative_sampler_cls: Optional[Type[NegativeSampler]]
    if training_loop_cls is SLCWATrainingLoop:
        negative_sampler_cls = negative_sampler_resolver.lookup(negative_sampler)
        assert negative_sampler_cls is not None
        study.set_user_attr("negative_sampler", negative_sampler_cls.get_normalized_name())
        logger.info(f"Using negative sampler: {negative_sampler_cls}")
    else:
        negative_sampler_cls = None
    # 7. Training
    if epochs is not None:
        training_kwargs = {} if training_kwargs is None else dict(training_kwargs)
        training_kwargs["num_epochs"] = epochs
    stopper_cls: Type[Stopper] = stopper_resolver.lookup(stopper)
    # the number of epochs is passed to training as ``num_epochs`` (cf. the ``epochs`` shortcut above); the legacy
    # key "epochs" is checked as well
    if (
        issubclass(stopper_cls, EarlyStopper)
        and training_kwargs_ranges
        and ("num_epochs" in training_kwargs_ranges or "epochs" in training_kwargs_ranges)
    ):
        raise ValueError("can not use early stopping while optimizing epochs")

    # 8. Evaluation
    evaluator_cls = evaluator_resolver.lookup(evaluator)
    study.set_user_attr("evaluator", evaluator_cls.get_normalized_name())
    logger.info(f"Using evaluator: {evaluator_cls}")
    resolved_metric = evaluator_cls.metric_result_cls.key_to_string(metric)
    study.set_user_attr("metric", resolved_metric)
    logger.info(f"Attempting to {direction} {resolved_metric}")
    study.set_user_attr("filter_validation_when_testing", filter_validation_when_testing)
    logger.info("Filter validation triples when testing: %s", filter_validation_when_testing)

    # 9. Tracking
    if not isinstance(result_tracker, ResultTracker):
        result_tracker = tracker_resolver.lookup(result_tracker)

    objective = Objective(
        # 1. Dataset
        dataset=dataset,
        dataset_kwargs=dataset_kwargs,
        training=training,
        testing=testing,
        validation=validation,
        evaluation_entity_whitelist=evaluation_entity_whitelist,
        evaluation_relation_whitelist=evaluation_relation_whitelist,
        # 2. Model
        model=model_cls,
        model_kwargs=model_kwargs,
        model_kwargs_ranges=model_kwargs_ranges,
        # 3. Loss
        loss=loss_cls,
        loss_kwargs=loss_kwargs,
        loss_kwargs_ranges=loss_kwargs_ranges,
        # 4. Regularizer
        regularizer=regularizer_cls,
        regularizer_kwargs=regularizer_kwargs,
        regularizer_kwargs_ranges=regularizer_kwargs_ranges,
        # 5. Optimizer
        optimizer=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
        optimizer_kwargs_ranges=optimizer_kwargs_ranges,
        # 5.1 Learning Rate Scheduler
        lr_scheduler=lr_scheduler_cls,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        lr_scheduler_kwargs_ranges=lr_scheduler_kwargs_ranges,
        # 6. Training Loop
        training_loop=training_loop_cls,
        training_loop_kwargs=training_loop_kwargs,
        negative_sampler=negative_sampler_cls,
        negative_sampler_kwargs=negative_sampler_kwargs,
        negative_sampler_kwargs_ranges=negative_sampler_kwargs_ranges,
        # 7. Training
        training_kwargs=training_kwargs,
        training_kwargs_ranges=training_kwargs_ranges,
        stopper=stopper_cls,
        stopper_kwargs=stopper_kwargs,
        # 8. Evaluation
        evaluator=evaluator_cls,
        evaluator_kwargs=evaluator_kwargs,
        evaluation_kwargs=evaluation_kwargs,
        filter_validation_when_testing=filter_validation_when_testing,
        # 9. Tracker
        result_tracker=result_tracker,
        result_tracker_kwargs=result_tracker_kwargs,
        # Optuna Misc.
        metric=resolved_metric,
        save_model_directory=save_model_directory,
        # Pipeline Misc.
        device=device,
    )

    # Invoke optimization of the objective function.
    study.optimize(
        cast(Callable[[Trial], float], objective),
        n_trials=n_trials,
        timeout=timeout,
        gc_after_trial=gc_after_trial,
        n_jobs=n_jobs or 1,
        # failed trials (e.g. out of memory) are recorded by optuna instead of aborting the whole study
        catch=(MemoryError, RuntimeError),
    )

    return HpoPipelineResult(
        study=study,
        objective=objective,
    )


def _get_kwargs(
    trial: Trial,
    prefix: str,
    *,
    default_kwargs_ranges: Mapping[str, Any],
    kwargs: Optional[Mapping[str, Any]] = None,
    kwargs_ranges: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
    """Suggest parameters, where user-provided ranges override the component's default ranges.

    :param trial: the optuna trial
    :param prefix: the prefix prepended to each suggested parameter's name
    :param default_kwargs_ranges: the component's default hyper-parameter ranges
    :param kwargs: fixed parameters, which are not suggested
    :param kwargs_ranges: ranges overriding (or extending) the default ones
    :return: a dictionary with fixed and sampled parameters
    """
    _kwargs_ranges = dict(default_kwargs_ranges)
    if kwargs_ranges is not None:
        _kwargs_ranges.update(kwargs_ranges)
    return suggest_kwargs(
        trial=trial,
        prefix=prefix,
        kwargs_ranges=_kwargs_ranges,
        kwargs=kwargs,
    )


def suggest_kwargs(
    trial: Trial,
    prefix: str,
    kwargs_ranges: Mapping[str, Any],
    kwargs: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
    """
    Suggest parameters from given dictionaries.

    :param trial:
        the optuna trial
    :param prefix:
        the prefix to be prepended to the name
    :param kwargs:
        a dictionary of fixed parameters
    :param kwargs_ranges:
        a dictionary of parameters to be sampled with their ranges.

    :return:
        a dictionary with fixed and sampled parameters
    """
    _kwargs: Dict[str, Any] = {}
    if kwargs:
        _kwargs.update(kwargs)

    for name, info in kwargs_ranges.items():
        if name in _kwargs:
            continue  # has been set by default, won't be suggested

        prefixed_name = f"{prefix}.{name}"

        # TODO: make it even easier to specify categorical strategies just as lists
        # if isinstance(info, (tuple, list, set)):
        #     info = dict(type='categorical', choices=list(info))

        dtype, low, high = info["type"], info.get("low"), info.get("high")
        # "log" could either be a boolean or a string
        log = info.get("log") in {True, "TRUE", "True", "true", "t", "YES", "Yes", "yes", "y"}
        if dtype in {int, "int"}:
            scale = info.get("scale")
            if scale in {"power_two", "power"}:
                _kwargs[name] = suggest_discrete_power_int(
                    trial=trial,
                    name=prefixed_name,
                    low=low,
                    high=high,
                    base=info.get("q") or info.get("base") or 2,
                )
            elif scale is None or scale == "linear":
                _kwargs[name] = trial.suggest_int(
                    name=prefixed_name,
                    low=low,
                    high=high,
                    step=info.get("q") or info.get("step") or 1,
                    log=log,
                )
            else:
                logger.warning(f"Unhandled scale {scale} for parameter {name} of data type {dtype}")

        elif dtype in {float, "float"}:
            _kwargs[name] = trial.suggest_float(
                name=prefixed_name,
                low=low,
                high=high,
                step=info.get("q") or info.get("step"),
                log=log,
            )
        elif dtype == "categorical":
            choices = info["choices"]
            _kwargs[name] = trial.suggest_categorical(name=prefixed_name, choices=choices)
        elif dtype in {bool, "bool"}:
            _kwargs[name] = trial.suggest_categorical(name=prefixed_name, choices=[True, False])
        else:
            logger.warning(f"Unhandled data type ({dtype}) for parameter {name}")

    return _kwargs


def suggest_discrete_power_int(trial: Trial, name: str, low: int, high: int, base: int = 2) -> int:
    """Suggest an integer from the powers ``base**low, base**(low + 1), ..., base**high``.

    :param trial: the optuna trial
    :param name: the parameter's name
    :param low: the smallest exponent (inclusive)
    :param high: the largest exponent (inclusive)
    :param base: the base of the powers
    :return: the suggested power
    :raises ValueError: if ``high`` is not greater than ``low``
    """
    if high <= low:
        raise ValueError(f"Upper bound {high} is not greater than lower bound {low}.")
    choices = [base**i for i in range(low, high + 1)]
    return cast(int, trial.suggest_categorical(name=name, choices=choices))


def _set_study_dataset(
    study: Study,
    *,
    dataset: Union[None, str, Dataset, Type[Dataset]] = None,
    training: Union[None, str, pathlib.Path, CoreTriplesFactory] = None,
    testing: Union[None, str, pathlib.Path, CoreTriplesFactory] = None,
    validation: Union[None, str, pathlib.Path, CoreTriplesFactory] = None,
):
    """Store information about the dataset in the study's user attributes.

    :param study: the study whose user attributes are set
    :param dataset: the dataset name, path, class, or instance
    :param training: the training triples, or a path thereof; only allowed when ``dataset`` is None
    :param testing: the testing triples, or a path thereof; only allowed when ``dataset`` is None
    :param validation: the validation triples, or a path thereof; only allowed when ``dataset`` is None
    :raises ValueError: if both a dataset and explicit splits are given
    :raises TypeError: if the dataset has an unsupported type
    """
    if dataset is not None:
        if training is not None or testing is not None or validation is not None:
            raise ValueError("Cannot specify dataset and training, testing and validation")
        elif isinstance(dataset, (str, pathlib.Path)):
            if isinstance(dataset, str) and has_dataset(dataset):
                study.set_user_attr("dataset", dataset_resolver.normalize(dataset))
            else:
                # otherwise, dataset refers to a file that should be automatically split
                study.set_user_attr("dataset", str(dataset))
        elif isinstance(dataset, Dataset) or (isinstance(dataset, type) and issubclass(dataset, Dataset)):
            # this could be custom data, so don't store anything. However, it's possible to check if this
            # was a pre-registered dataset. If that's the desired functionality, we can uncomment the following:
            # dataset_name = dataset.get_normalized_name()  # this works both on instances and classes
            # if has_dataset(dataset_name):
            #     study.set_user_attr('dataset', dataset_name)
            pass
        else:
            raise TypeError(f"Dataset is invalid type: ({type(dataset)}) {dataset}")
    else:
        # only paths can be stored; triples factories are user-defined data
        if isinstance(training, (str, pathlib.Path)):
            study.set_user_attr("training", str(training))
        if isinstance(testing, (str, pathlib.Path)):
            study.set_user_attr("testing", str(testing))
        if isinstance(validation, (str, pathlib.Path)):
            study.set_user_attr("validation", str(validation))