DirectionAverageInteraction
- class DirectionAverageInteraction(base: str | Interaction[Tensor, Tensor, Tensor] | type[Interaction[Tensor, Tensor, Tensor]], base_kwargs: Mapping[str, Any] | None = None)
Bases: Interaction[tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]

The directional average interaction module.
This can be considered a generalization of the SimplE interaction module that can be parametrized with any other interaction module, rather than only pykeen.nn.modules.DistMultInteraction.

A separate representation is learned for each entity \(e \in \mathcal{E}\) for when it appears as the subject of a triple, \(\mathbf{e}_h \in \mathbb{R}^d\), and for when it appears as the object of a triple, \(\mathbf{e}_t \in \mathbb{R}^d\). Similarly, two representations are learned for each relation: \(\textbf{r}_{\rightarrow}\) for the forward triple and \(\textbf{r}_{\leftarrow}\) for the backward triple.
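The parameter layout described above can be sketched as four embedding tables and a lookup that turns triples into the representation pairs consumed by forward. This is a minimal sketch using NumPy arrays in place of PyTorch tensors; the table names, the sizes, and the ordering inside each pair (subject role first, forward relation first) are assumptions for illustration, not PyKEEN's actual storage.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
num_entities, num_relations, dim = 5, 3, 4
rng = np.random.default_rng(seed=0)

# Two tables per entity: e_h (entity as subject) and e_t (entity as object).
entity_as_head = rng.normal(size=(num_entities, dim))
entity_as_tail = rng.normal(size=(num_entities, dim))
# Two tables per relation: forward (r_->) and backward (r_<-).
relation_forward = rng.normal(size=(num_relations, dim))
relation_backward = rng.normal(size=(num_relations, dim))


def lookup(triples: np.ndarray):
    """Turn an (n, 3) array of (head, relation, tail) ids into representation pairs.

    Assumed pair ordering for this sketch: (subject role, object role) for
    entities and (forward, backward) for relations.
    """
    h_ids, r_ids, t_ids = triples[:, 0], triples[:, 1], triples[:, 2]
    h = (entity_as_head[h_ids], entity_as_tail[h_ids])
    r = (relation_forward[r_ids], relation_backward[r_ids])
    t = (entity_as_head[t_ids], entity_as_tail[t_ids])
    return h, r, t


triples = np.array([[0, 1, 2], [3, 0, 4]])
h, r, t = lookup(triples)
print([x.shape for x in h + r + t])  # six arrays of shape (2, 4)
```

Each entity therefore contributes two rows to every triple it appears in, and which row is used depends only on the role the entity plays in a given direction.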
The score is then obtained by averaging the forward and backward values of the interaction function:
\[\frac{ f(\textbf{h}_{h}, \textbf{r}_{\rightarrow}, \textbf{t}_{t}) + f(\textbf{t}_{h}, \textbf{r}_{\leftarrow}, \textbf{h}_{t}) }{2}\]
where \(\textbf{h}_{h}, \textbf{h}_{t}\) are the subject and object representations of the head entity, \(\textbf{t}_{h}, \textbf{t}_{t}\) are those of the tail entity, and f is the base interaction function. If pykeen.nn.modules.DistMultInteraction is used as the base, this becomes pykeen.nn.modules.SimplEInteraction.
Todo
Can we generalize the type annotations for this from FloatTensor to HeadRepresentation, etc.?
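The averaging rule in the formula above can be sketched as a small function that takes any base scoring function f. This is a minimal NumPy sketch, not PyKEEN's implementation; the helper names `distmult` and `direction_average` and the assumed pair layout h = (h_h, h_t), r = (r_fwd, r_bwd), t = (t_h, t_t) are illustrative. It also checks that a DistMult-style base yields the SimplE form and that the output has shape batch_dims.

```python
from typing import Callable, Tuple

import numpy as np

Array = np.ndarray
Pair = Tuple[Array, Array]


def distmult(h: Array, r: Array, t: Array) -> Array:
    """DistMult-style trilinear product, reducing the last (embedding) axis."""
    return np.sum(h * r * t, axis=-1)


def direction_average(base: Callable[[Array, Array, Array], Array], h: Pair, r: Pair, t: Pair) -> Array:
    """Average of the forward and backward base scores.

    Assumed pair layout: h = (h_h, h_t), r = (r_fwd, r_bwd), t = (t_h, t_t).
    """
    h_h, h_t = h
    r_fwd, r_bwd = r
    t_h, t_t = t
    forward = base(h_h, r_fwd, t_t)   # f(h_h, r_->, t_t)
    backward = base(t_h, r_bwd, h_t)  # f(t_h, r_<-, h_t)
    return (forward + backward) / 2


rng = np.random.default_rng(seed=0)
batch_dims, d = (2, 3), 4
# Three pairs (h, r, t), each holding two arrays of shape (*batch_dims, d).
h, r, t = [tuple(rng.normal(size=(*batch_dims, d)) for _ in range(2)) for _ in range(3)]

scores = direction_average(distmult, h, r, t)
print(scores.shape)  # (2, 3), i.e. batch_dims

# With the DistMult base, the result equals the SimplE formula written out explicitly.
simple = 0.5 * (np.sum(h[0] * r[0] * t[1], axis=-1) + np.sum(t[0] * r[1] * h[1], axis=-1))
print(np.allclose(scores, simple))  # True
```

Because `base` is just a callable, swapping in another scoring function changes f without touching the averaging logic, which is the generalization this module describes.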
Initialize the interaction module.
- Parameters:
base (LookupOrType[Interaction[FloatTensor, FloatTensor, FloatTensor]]) – the base interaction, given as a name, an instance, or a class.
base_kwargs (OptionalKwargs) – keyword arguments used to instantiate the base interaction.
Note
The parameter pair (base, base_kwargs) is used for interaction_resolver.
An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
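The way a (base, base_kwargs) pair can be turned into a concrete interaction is sketched below with a tiny hand-written registry. Everything here is hypothetical: `FakeInteraction`, `FakeDistMult`, `REGISTRY`, and `resolve_interaction` are stand-ins written for illustration and do not reproduce the class-resolver library's API. They only mirror the three accepted forms of `base` from the signature: a string name, an instance, or a class.

```python
from typing import Any, Dict, Mapping, Optional, Type, Union


class FakeInteraction:
    """Hypothetical stand-in for an interaction base class."""


class FakeDistMult(FakeInteraction):
    """Hypothetical stand-in for a concrete interaction; it just records its kwargs."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs: Dict[str, Any] = kwargs


# Hypothetical registry mapping normalized names to classes.
REGISTRY: Dict[str, Type[FakeInteraction]] = {"distmult": FakeDistMult}


def resolve_interaction(
    base: Union[str, FakeInteraction, Type[FakeInteraction]],
    base_kwargs: Optional[Mapping[str, Any]] = None,
) -> FakeInteraction:
    """Resolve (base, base_kwargs) into an instance (sketch only)."""
    if isinstance(base, FakeInteraction):
        # Already an instance: returned unchanged; this sketch ignores base_kwargs here.
        return base
    if isinstance(base, str):
        # Normalize the name: case-insensitive, optional "Interaction" suffix.
        name = base.lower()
        if name.endswith("interaction"):
            name = name[: -len("interaction")]
        base = REGISTRY[name]
    # A class (given directly or looked up by name) is instantiated with the kwargs.
    return base(**dict(base_kwargs or {}))


resolved = resolve_interaction("DistMultInteraction", {"example_option": 1})
print(type(resolved).__name__, resolved.kwargs)
```

Accepting all three forms lets callers pass a short name from a configuration file or a fully constructed object from code through the same parameter.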
Methods Summary
forward(h, r, t) – Evaluate the interaction function.
Methods Documentation
- forward(h: tuple[Tensor, Tensor], r: tuple[Tensor, Tensor], t: tuple[Tensor, Tensor]) → Tensor
Evaluate the interaction function.
See also
Interaction.forward for a detailed description of the generic batched form of the interaction function.
- Parameters:
h (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The head representations.
r (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The relation representations.
t (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The tail representations.
- Returns:
shape: batch_dims
The scores.
- Return type: