"""Implementation of the ComplExLiteral model."""
from collections.abc import Mapping
from typing import Any, ClassVar
import torch
import torch.nn as nn
from .base import LiteralModel
from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import BCEWithLogitsLoss, Loss
from ...nn import ComplexSeparatedCombination, ConcatProjectionCombination
from ...nn.modules import ComplExInteraction, Interaction
from ...triples import TriplesNumericLiteralsFactory
# Public API of this module. (A stray ``[docs]`` line, a Sphinx "view source" artifact that
# raised NameError at import time, has been removed.)
__all__ = [
    "ComplExLiteral",
]
class ComplExLiteral(LiteralModel):
    """An implementation of the LiteralE model with the ComplEx interaction from [kristiadi2018]_.

    This module is a configuration of the general :class:`pykeen.models.LiteralModel` with the
    :class:`pykeen.nn.modules.ComplExInteraction` interaction and a
    :class:`pykeen.nn.ComplexSeparatedCombination`. That combination merges the complex-valued
    embeddings with the numeric literals by applying a :class:`pykeen.nn.ConcatProjectionCombination`
    separately to the real and imaginary parts.
    ---
    name: ComplEx Literal
    citation:
        author: Kristiadi
        year: 2018
        link: https://arxiv.org/abs/1802.00934
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        "embedding_dim": DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
        "input_dropout": DEFAULT_DROPOUT_HPO_RANGE,
    }
    #: The default loss function class
    loss_default: ClassVar[type[Loss]] = BCEWithLogitsLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = {}
    #: The interaction class handed to :class:`pykeen.models.LiteralModel` as ``interaction``
    interaction_cls: ClassVar[type[Interaction]] = ComplExInteraction

    def __init__(
        self,
        triples_factory: TriplesNumericLiteralsFactory,
        embedding_dim: int = 50,
        input_dropout: float = 0.2,
        **kwargs,
    ) -> None:
        """Initialize the model.

        :param triples_factory:
            the triples factory with numeric literals; ``triples_factory.literal_shape[0]`` is used
            as the literal input dimension of the combination
        :param embedding_dim:
            the (complex-valued) embedding dimension of entities and relations; also the output
            dimension of the per-part combination
        :param input_dropout:
            passed as ``dropout`` to the :class:`pykeen.nn.ConcatProjectionCombination` used for
            each of the real and imaginary parts
        :param kwargs:
            remaining keyword-based parameters passed to :class:`pykeen.models.LiteralModel`
        """
        super().__init__(
            triples_factory=triples_factory,
            interaction=self.interaction_cls,
            # ComplEx scores complex-valued embeddings, hence the complex dtype for both
            # entity and relation representations
            entity_representations_kwargs=[
                {
                    "shape": embedding_dim,
                    "initializer": nn.init.xavier_normal_,
                    "dtype": torch.complex64,
                },
            ],
            relation_representations_kwargs=[
                {
                    "shape": embedding_dim,
                    "initializer": nn.init.xavier_normal_,
                    "dtype": torch.complex64,
                },
            ],
            combination=ComplexSeparatedCombination,
            combination_kwargs={
                # the individual combination applied to the real and the imaginary part
                "combination": ConcatProjectionCombination,
                "combination_kwargs": {
                    # inputs: one embedding part and the literal vector
                    "input_dims": [embedding_dim, triples_factory.literal_shape[0]],
                    # output has the same size as one embedding part
                    "output_dim": embedding_dim,
                    "bias": True,
                    "dropout": input_dropout,
                    "activation": nn.Tanh,
                    "activation_kwargs": None,
                },
            },
            **kwargs,
        )