NTNInteraction
- class NTNInteraction(activation: str | Module | type[Module] | None = None, activation_kwargs: Mapping[str, Any] | None = None)[source]
Bases: FunctionalInteraction[Tensor, tuple[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
A stateful module for the NTN interaction function.
Initialize NTN with the given non-linear activation function.
- Parameters:
activation (HintOrType[nn.Module]) – A non-linear activation function. Defaults to the hyperbolic tangent torch.nn.Tanh if None; otherwise, it is looked up via pykeen.utils.activation_resolver.
activation_kwargs (Mapping[str, Any] | None) – If the activation is passed as a class, these keyword arguments are used during its instantiation.
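The resolution rules for `activation` and `activation_kwargs` can be sketched in plain Python. This is an illustrative approximation, not PyKEEN's code: `resolve_activation`, `REGISTRY`, and the `Tanh`/`LeakyReLU` stand-ins are hypothetical names, the lowercase name lookup is a simplification of the real resolver, and forwarding the keyword arguments when a *name* is given (not just a class) is an assumption beyond what the parameter description states.

```python
import math
from typing import Any, Callable, Mapping, Optional, Union


class Tanh:
    """Hypothetical pure-Python stand-in for torch.nn.Tanh."""

    def __call__(self, x: float) -> float:
        return math.tanh(x)


class LeakyReLU:
    """Hypothetical pure-Python stand-in for torch.nn.LeakyReLU."""

    def __init__(self, negative_slope: float = 0.01) -> None:
        self.negative_slope = negative_slope

    def __call__(self, x: float) -> float:
        return x if x >= 0 else self.negative_slope * x


# Hypothetical name-to-class registry standing in for pykeen.utils.activation_resolver.
REGISTRY = {"tanh": Tanh, "leakyrelu": LeakyReLU}


def resolve_activation(
    activation: Union[str, type, Callable[[float], float], None] = None,
    activation_kwargs: Optional[Mapping[str, Any]] = None,
) -> Callable[[float], float]:
    """Resolve an activation hint following the documented rules (sketch)."""
    kwargs = dict(activation_kwargs or {})
    if activation is None:
        # Documented default: the hyperbolic tangent.
        return Tanh()
    if isinstance(activation, str):
        # A name: look it up (simplified: lowercase match) and instantiate it.
        return REGISTRY[activation.lower()](**kwargs)
    if isinstance(activation, type):
        # A class: instantiate it with the given keyword arguments.
        return activation(**kwargs)
    # Already an instance: use it as-is; activation_kwargs is ignored.
    return activation
```

For example, `resolve_activation(LeakyReLU, {"negative_slope": 0.2})(-1.0)` evaluates to `-0.2`, while passing an already constructed instance returns that same object untouched.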
Attributes Summary
relation_shape
The symbolic shapes for relation representations.
Methods Summary
func(t, w, vh, vt, b, u, activation)
Evaluate the NTN interaction function.
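To make the NTN score concrete, here is a plain-Python reference for a single (h, r, t) triple, computing u_r^T act(h W_r t + V_h h + V_t t + b_r) one hidden unit at a time. It is a sketch for understanding only: PyKEEN's `func` works on batched tensors with broadcasting over `*batch_dims`, whereas this version handles one triple with nested lists. The helper names `ntn_score` and `dot` are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]
Matrix = Sequence[Sequence[float]]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def ntn_score(
    h: Vector,                # shape (d,): head representation
    t: Vector,                # shape (d,): tail representation
    w: Sequence[Matrix],      # shape (k, d, d): one bilinear slice per hidden unit
    vh: Matrix,               # shape (k, d): head transformation V_h
    vt: Matrix,               # shape (k, d): tail transformation V_t
    b: Vector,                # shape (k,): relation-specific offset b_r
    u: Vector,                # shape (k,): final linear weights u_r
    activation: Callable[[float], float] = math.tanh,
) -> float:
    """Score one triple: u_r^T act(h W_r t + V_h h + V_t t + b_r)."""
    hidden: List[float] = []
    for w_i, vh_i, vt_i, b_i in zip(w, vh, vt, b):
        # Bilinear term h^T W_i t: first compute W_i t, then take the dot product with h.
        bilinear = dot(h, [dot(row, t) for row in w_i])
        hidden.append(activation(bilinear + dot(vh_i, h) + dot(vt_i, t) + b_i))
    # Final linear layer: weight the k activated hidden units by u_r and sum them.
    return dot(u, hidden)
```

With k = 1, d = 2, h = [1, 0], t = [0, 1], W = [[[0, 2], [0, 0]]], V_h = [[1, 0]], V_t = [[0, 1]], b = [0], and u = [0.5], the pre-activation is 2 + 1 + 1 + 0 = 4, so an identity activation yields a score of 2.0 and the default yields 0.5 · tanh(4).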
Attributes Documentation
- relation_shape: Sequence[str] = ('kdd', 'kd', 'kd', 'k', 'k')
The symbolic shapes for relation representations.
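One way to read `relation_shape` is that each character names one axis: `k` for the number of hidden units (tensor slices) and `d` for the embedding dimension. This reading, and the assumption that the five entries line up in order with `func`'s relation arguments `w`, `vh`, `vt`, `b`, `u`, is inferred from the parameter shapes documented for `func`. The sketch below expands the symbols into concrete shapes and counts parameters per relation; `concrete_shapes` and `parameters_per_relation` are hypothetical helpers.

```python
import math
from typing import Dict, Sequence, Tuple

# The documented symbolic shapes, assumed to match func's relation arguments (w, vh, vt, b, u).
RELATION_SHAPE: Sequence[str] = ("kdd", "kd", "kd", "k", "k")


def concrete_shapes(
    symbolic: Sequence[str], sizes: Dict[str, int]
) -> Tuple[Tuple[int, ...], ...]:
    """Expand each symbolic shape string into a tuple of sizes, one axis per character."""
    return tuple(tuple(sizes[symbol] for symbol in shape) for shape in symbolic)


def parameters_per_relation(symbolic: Sequence[str], sizes: Dict[str, int]) -> int:
    """Count the scalar parameters one relation needs across all its representations."""
    return sum(math.prod(shape) for shape in concrete_shapes(symbolic, sizes))


shapes = concrete_shapes(RELATION_SHAPE, {"k": 4, "d": 3})
print(shapes)  # ((4, 3, 3), (4, 3), (4, 3), (4,), (4,))
print(parameters_per_relation(RELATION_SHAPE, {"k": 4, "d": 3}))  # 68
```

In general each relation carries k·d² + 2·k·d + 2·k scalars, so the bilinear slices in W_r dominate the parameter count as d grows.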
Methods Documentation
- func(t: Tensor, w: Tensor, vh: Tensor, vt: Tensor, b: Tensor, u: Tensor, activation: Module) → Tensor
Evaluate the NTN interaction function.
\[f(h,r,t) = u_r^T \operatorname{act}(h W_r t + V_r h + V_r' t + b_r)\]
- Parameters:
h (Tensor) – shape: (*batch_dims, dim) The head representations.
w (Tensor) – shape: (*batch_dims, k, dim, dim) The relation-specific transformation matrix W_r.
vh (Tensor) – shape: (*batch_dims, k, dim) The head transformation matrix V_h (V_r in the formula above).
vt (Tensor) – shape: (*batch_dims, k, dim) The tail transformation matrix V_t (V_r' in the formula above).
b (Tensor) – shape: (*batch_dims, k) The relation-specific offset b_r.
u (Tensor) – shape: (*batch_dims, k) The relation-specific final linear transformation u_r.
t (Tensor) – shape: (*batch_dims, dim) The tail representations.
activation (Module) – The activation function.
- Returns:
shape: batch_dims The scores.
- Return type: