# -*- coding: utf-8 -*-
"""Type hints for PyKEEN."""
from typing import Callable, Mapping, NamedTuple, Sequence, TypeVar, Union, cast
import numpy as np
import torch
#: The public names exported by this module, grouped by category.
__all__ = [
    # Generic helper types
    "Hint",
    "Mutation",
    "OneOrSequence",
    # Triples and label-to-ID mappings
    "LabeledTriples",
    "MappedTriples",
    "EntityMapping",
    "RelationMapping",
    # Miscellaneous hints
    "DeviceHint",
    "TorchRandomHint",
    # Functions operating on tensors
    "Initializer",
    "Normalizer",
    "Constrainer",
    "cast_constrainer",
    # Tensor representations
    "HeadRepresentation",
    "RelationRepresentation",
    "TailRepresentation",
    # Named tuples
    "GaussianDistribution",
    "ScorePack",
]
#: Unconstrained type variable used to parametrize the generic aliases below
X = TypeVar("X")

#: Either ``None``, a string, or an instance of the parametrizing type
Hint = Union[None, str, X]

#: A callable that takes one object and returns an object of the same type
Mutation = Callable[[X], X]

#: Either a single object, or a sequence of objects of the same type
OneOrSequence = Union[X, Sequence[X]]

#: Labeled triples are annotated as :class:`numpy.ndarray`
LabeledTriples = np.ndarray

#: Mapped triples are annotated as :class:`torch.LongTensor`
MappedTriples = torch.LongTensor

#: A mapping with string keys and integer values, used for entities
EntityMapping = Mapping[str, int]

#: A mapping with string keys and integer values, used for relations
RelationMapping = Mapping[str, int]

#: A :data:`Mutation` of float tensors that is used for initialization
Initializer = Mutation[torch.FloatTensor]

#: A :data:`Mutation` of float tensors that is used for normalization
Normalizer = Mutation[torch.FloatTensor]

#: A :data:`Mutation` of float tensors that is used to apply a constraint
Constrainer = Mutation[torch.FloatTensor]
def cast_constrainer(f) -> Constrainer:
    """Cast a constrainer function with :func:`typing.cast`.

    :param f: The object to be treated as a :data:`Constrainer`.
    :return: The same object ``f``. :func:`typing.cast` does not convert or
        check anything at runtime; it only informs static type checkers.
    """
    return cast(Constrainer, f)
#: A :data:`Hint` parametrized with :class:`torch.device`
DeviceHint = Hint[torch.device]

#: A :data:`Hint` parametrized with :class:`torch.Generator`
TorchRandomHint = Hint[torch.Generator]

#: The type variable for head representations, bound to either one float tensor
#: or a sequence of them. It is used in :class:`pykeen.models.Model`,
#: :class:`pykeen.nn.modules.Interaction`, etc.
HeadRepresentation = TypeVar(
    "HeadRepresentation",
    bound=OneOrSequence[torch.FloatTensor],
)

#: The type variable for relation representations, bound to either one float
#: tensor or a sequence of them. It is used in :class:`pykeen.models.Model`,
#: :class:`pykeen.nn.modules.Interaction`, etc.
RelationRepresentation = TypeVar(
    "RelationRepresentation",
    bound=OneOrSequence[torch.FloatTensor],
)

#: The type variable for tail representations, bound to either one float tensor
#: or a sequence of them. It is used in :class:`pykeen.models.Model`,
#: :class:`pykeen.nn.modules.Interaction`, etc.
TailRepresentation = TypeVar(
    "TailRepresentation",
    bound=OneOrSequence[torch.FloatTensor],
)
class GaussianDistribution(NamedTuple):
    """A gaussian distribution with diagonal covariance matrix.

    As a :class:`typing.NamedTuple`, instances are immutable and can be
    unpacked positionally as ``(mean, diagonal_covariance)``.
    """

    #: The mean of the distribution
    mean: torch.FloatTensor

    #: The diagonal of the covariance matrix
    diagonal_covariance: torch.FloatTensor
class ScorePack(NamedTuple):
    """A pair of result triples and scores.

    As a :class:`typing.NamedTuple`, instances are immutable and can be
    unpacked positionally as ``(result, scores)``.
    """

    #: The result triples
    result: torch.LongTensor

    #: The scores belonging to the result triples
    scores: torch.FloatTensor