NTN
- class NTN(*, embedding_dim: int = 100, num_slices: int = 4, non_linearity: str | Module | type[Module] | None = None, non_linearity_kwargs: Mapping[str, Any] | None = None, entity_initializer: str | Callable[[Tensor], Tensor] | None = None, **kwargs)[source]
Bases: ERModel[Tensor, tuple[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
An implementation of NTN from [socher2013].
NTN represents entities using a \(d\)-dimensional vector. Relations are represented by
a \(k \times d \times d\)-dimensional tensor, \(\mathbf{W} \in \mathbb{R}^{k \times d \times d}\),
a \(k \times 2d\)-dimensional matrix, \(\mathbf{V} \in \mathbb{R}^{k \times 2d}\), and
two \(k\)-dimensional vectors, \(\mathbf{b}, \mathbf{u} \in \mathbb{R}^{k}\), where \(k\) is the number of slices.
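As a rough illustration of these shapes, the NumPy sketch below scores a single triple with the formulation of [socher2013], \(f(h, r, t) = \mathbf{u}^\top \tanh(\mathbf{h}^\top \mathbf{W} \mathbf{t} + \mathbf{V} [\mathbf{h}; \mathbf{t}] + \mathbf{b})\), assuming the default tanh non-linearity. The helper `ntn_score` and the random parameters are hypothetical; this is not PyKEEN's own interaction code. It also counts how many values one relation holds.

```python
import numpy as np


def ntn_score(h, t, W, V, b, u):
    """Score one (head, relation, tail) triple in the NTN style.

    Shapes: h, t -> (d,); W -> (k, d, d); V -> (k, 2d); b, u -> (k,).
    """
    # Bilinear term h^T W_i t for every slice i, giving a (k,) vector
    bilinear = np.einsum("i,kij,j->k", h, W, t)
    # Standard feed-forward term on the concatenated pair [h; t], also (k,)
    linear = V @ np.concatenate([h, t])
    # Apply the non-linearity slice-wise and combine the slices with u
    return float(u @ np.tanh(bilinear + linear + b))


rng = np.random.default_rng(seed=0)
d, k = 5, 4  # embedding dimension and number of slices
h, t = rng.normal(size=d), rng.normal(size=d)
W = rng.normal(size=(k, d, d))
V = rng.normal(size=(k, 2 * d))
b, u = rng.normal(size=k), rng.normal(size=k)

score = ntn_score(h, t, W, V, b, u)
# One relation holds k*d*d + 2*k*d + 2*k = k*(d^2 + 2d + 2) values
per_relation = W.size + V.size + b.size + u.size
print(per_relation)  # 148
```

With \(d = 5\) and \(k = 4\), the per-relation count is \(4 \cdot (25 + 10 + 2) = 148\), which is the \(k(d^2 + 2d + 2)\) factor multiplying \(R\) in the total parameter count.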
Denoting the number of entities by \(E\) and the number of relations by \(R\), the total number of parameters is thus given by
\[dE + k(d^2 + 2d + 2)R\]
All representations are stored as Embedding. NTNInteraction is used as interaction upon those representations.
Note
We split the original \(k \times 2d\)-dimensional \(\mathbf{V}\) matrix into two parts of shape \(k \times d\), one applied to the head and one to the tail embedding, to support more efficient 1:n scoring, e.g., in the score_h() or score_t() setting.
See also
Original Implementation (Matlab): https://github.com/khurram18/NeuralTensorNetworks
TensorFlow: https://github.com/dddoss/tensorflow-socher-ntn
Keras: https://github.com/dapurv5/keras-neural-tensor-layer
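The split of \(\mathbf{V}\) mentioned in the note above rests on the identity \(\mathbf{V}[\mathbf{h}; \mathbf{t}] = \mathbf{V}_h \mathbf{h} + \mathbf{V}_t \mathbf{t}\), where \(\mathbf{V}_h\) and \(\mathbf{V}_t\) are the two \(k \times d\) halves. The NumPy sketch below, using hypothetical names and random parameters rather than PyKEEN's implementation, scores one head against all candidate tails at once (the score_t() setting): the head-side terms are computed once and broadcast, while the tail-side terms become a single batched product.

```python
import numpy as np

rng = np.random.default_rng(seed=1)
d, k, num_entities = 6, 3, 10
entities = rng.normal(size=(num_entities, d))  # one row per entity
W = rng.normal(size=(k, d, d))
V = rng.normal(size=(k, 2 * d))
b, u = rng.normal(size=k), rng.normal(size=k)

# Split V into the head half and the tail half, each of shape (k, d)
V_h, V_t = V[:, :d], V[:, d:]

# V [h; t] equals V_h h + V_t t, so the concatenation is never needed
h, t = entities[0], entities[1]
assert np.allclose(V @ np.concatenate([h, t]), V_h @ h + V_t @ t)


def score_all_tails(h, tails, W, V_h, V_t, b, u):
    """Score (h, r, t') for every row t' of `tails`; returns shape (n,)."""
    hW = np.einsum("i,kij->kj", h, W)  # (k, d): head folded into W once
    bilinear = tails @ hW.T            # (n, k): h^T W_i t' for all t'
    head_part = V_h @ h + b            # (k,): shared by all candidates
    tail_part = tails @ V_t.T          # (n, k): V_t t' for all t'
    return np.tanh(bilinear + head_part + tail_part) @ u  # (n,)


scores = score_all_tails(h, entities, W, V_h, V_t, b, u)
print(scores.shape)  # (10,)
```

Because only \(\mathbf{V}_t\) touches the candidate tails, the work that depends on the head is done once instead of once per candidate; score_h() is symmetric with the roles of the halves swapped.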
Initialize NTN.
- Parameters:
embedding_dim (int) – The entity embedding dimension \(d\). Usually \(d \in [50, 350]\).
num_slices (int) – The number of slices \(k\) of the relation parameters \(\mathbf{W}\), \(\mathbf{V}\), \(\mathbf{b}\), and \(\mathbf{u}\).
non_linearity (str | Module | type[Module] | None) – A non-linear activation function. Defaults to the hyperbolic tangent torch.nn.Tanh.
non_linearity_kwargs (Mapping[str, Any] | None) – If the non_linearity is passed as a class, these keyword arguments are used during its instantiation.
entity_initializer (str | Callable[[Tensor], Tensor] | None) – Entity initializer function. Defaults to torch.nn.init.uniform_().
kwargs – Remaining keyword arguments to forward to ERModel.
Note
The parameter pair (non_linearity, non_linearity_kwargs) is used for class_resolver.contrib.torch.activation_resolver.
An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
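To illustrate what resolving this pair means, here is a hypothetical, pure-Python stand-in for the resolver pattern; it is not class_resolver's API, and its names (`resolve_activation`, `REGISTRY`) are invented for this sketch. It accepts the same kinds of values as non_linearity: None (falling back to tanh), a string name, a class (instantiated with the keyword arguments), or an already-built instance. The two activation classes are toy stand-ins for the torch.nn modules and act on single floats.

```python
import math
from typing import Any, Mapping, Optional, Union


class Tanh:
    """Toy stand-in for torch.nn.Tanh acting on a single float."""

    def __call__(self, x: float) -> float:
        return math.tanh(x)


class LeakyReLU:
    """Toy stand-in for torch.nn.LeakyReLU with a configurable slope."""

    def __init__(self, negative_slope: float = 0.01):
        self.negative_slope = negative_slope

    def __call__(self, x: float) -> float:
        return x if x >= 0 else self.negative_slope * x


# Hypothetical registry: lower-cased class name -> class
REGISTRY = {cls.__name__.lower(): cls for cls in (Tanh, LeakyReLU)}


def resolve_activation(
    query: Union[str, type, object, None],
    kwargs: Optional[Mapping[str, Any]] = None,
    default: type = Tanh,
):
    """Turn a (query, kwargs) pair into an activation instance."""
    kwargs = dict(kwargs or {})
    if query is None:
        return default(**kwargs)  # fall back to the default class
    if isinstance(query, str):
        query = REGISTRY[query.lower()]  # case-insensitive name lookup
    if isinstance(query, type):
        return query(**kwargs)  # instantiate the class with the kwargs
    return query  # already an instance: use it unchanged


act = resolve_activation("LeakyReLU", {"negative_slope": 0.2})
print(act(-1.0))  # -0.2
```

Normalizing every accepted form to an instance in one place keeps the model constructor simple; the real resolver additionally handles name normalization and error reporting in its own way.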
Attributes Summary
The default strategy for optimizing the model's hyper-parameters
Attributes Documentation