NTN
- class NTN(*, embedding_dim: int = 100, num_slices: int = 4, non_linearity: str | Module | type[Module] | None = None, non_linearity_kwargs: Mapping[str, Any] | None = None, entity_initializer: str | Callable[[Tensor], Tensor] | None = None, **kwargs)
Bases: ERModel[Tensor, tuple[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
An implementation of NTN from [socher2013].
NTN represents entities using a \(d\)-dimensional vector. Relations are represented by
a \(k \times d \times d\)-dimensional tensor, \(\mathbf{W} \in \mathbb{R}^{k \times d \times d}\),
a \(k \times 2d\)-dimensional matrix, \(\mathbf{V} \in \mathbb{R}^{k \times 2d}\), and
two \(k\)-dimensional vectors, \(\mathbf{b}, \mathbf{u} \in \mathbb{R}^{k}\).
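Multiplying out the shapes listed above recovers the closed-form parameter count given below. The following pure-Python sketch does exactly that. The helper name ntn_num_parameters is hypothetical and not part of PyKEEN, and \(\mathbf{V}\) is counted as a single \(k \times 2d\) matrix; splitting it into two \(k \times d\) halves leaves the count unchanged.

```python
from math import prod


def ntn_num_parameters(num_entities: int, num_relations: int, d: int, k: int) -> int:
    """Count NTN parameters by multiplying out the representation shapes (hypothetical helper)."""
    entity_shapes = [(d,)]  # one d-dimensional vector per entity
    relation_shapes = [
        (k, d, d),  # W: one d x d matrix per slice
        (k, 2 * d),  # V: maps the concatenation [h; t] to k values
        (k,),  # b: one bias per slice
        (k,),  # u: one output weight per slice
    ]
    per_entity = sum(prod(shape) for shape in entity_shapes)
    per_relation = sum(prod(shape) for shape in relation_shapes)
    return num_entities * per_entity + num_relations * per_relation


# Cross-check against the closed form dE + k(d^2 + 2d + 2)R for a few arbitrary sizes.
for n_ent, n_rel, dim, slices in [(14, 55, 50, 4), (100, 10, 8, 2), (1, 1, 1, 1)]:
    expected = dim * n_ent + slices * (dim**2 + 2 * dim + 2) * n_rel
    assert ntn_num_parameters(n_ent, n_rel, dim, slices) == expected

print(ntn_num_parameters(num_entities=14, num_relations=55, d=50, k=4))  # 573140
```

Note how the \(d^2\) term dominates: for \(d = 50\) and \(k = 4\), each relation already carries 10,000 entries in \(\mathbf{W}\) alone.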
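Denoting the number of entities by \(E\) and the number of relations by \(R\), the total number of parameters is thus given by
\[dE + k(d^2 + 2d + 2)R\]
since each relation contributes \(kd^2\) entries for \(\mathbf{W}\), \(2kd\) for \(\mathbf{V}\), and \(k\) each for \(\mathbf{b}\) and \(\mathbf{u}\).
All representations are stored as Embedding. NTNInteraction is used as the interaction function on top of these representations.
Note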
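We split the original \(k \times 2d\)-dimensional \(\mathbf{V}\) matrix into two parts \(\mathbf{V}_h, \mathbf{V}_t\) of shape \(k \times d\), one acting on the head and one on the tail, so that \(\mathbf{V}[\mathbf{h}; \mathbf{t}] = \mathbf{V}_h \mathbf{h} + \mathbf{V}_t \mathbf{t}\). This supports more efficient 1:n scoring, e.g., in the score_h() or score_t() setting, because the term for the fixed entity is computed once and reused for all candidates.
See also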
Original Implementation (Matlab): https://github.com/khurram18/NeuralTensorNetworks
TensorFlow: https://github.com/dddoss/tensorflow-socher-ntn
Keras: https://github.com/dapurv5/keras-neural-tensor-layer
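The score itself is not spelled out on this page. In [socher2013], a triple is scored as \(\mathbf{u}^\top f(\mathbf{h}^\top \mathbf{W}^{[1:k]} \mathbf{t} + \mathbf{V}[\mathbf{h}; \mathbf{t}] + \mathbf{b})\), where the bilinear term is evaluated once per slice and \(f\) is the non-linearity (tanh by default). The NumPy sketch below illustrates this under those assumptions. It is not PyKEEN's NTNInteraction, and ntn_score and ntn_score_all_tails are hypothetical names. With \(\mathbf{V}\) split column-wise into a head half \(\mathbf{V}_h\) and a tail half \(\mathbf{V}_t\), as described in the Note above, everything that depends only on the head is computed once and broadcast over all candidate tails.

```python
import numpy as np


def ntn_score(h, t, w, vh, vt, b, u, activation=np.tanh):
    """Score a single triple given relation parameters (w, vh, vt, b, u).

    Shapes: h, t: (d,); w: (k, d, d); vh, vt: (k, d); b, u: (k,).
    """
    bilinear = np.einsum("i,kij,j->k", h, w, t)  # h^T W_s t for every slice s
    hidden = activation(bilinear + vh @ h + vt @ t + b)  # (k,)
    return float(u @ hidden)


def ntn_score_all_tails(h, tails, w, vh, vt, b, u, activation=np.tanh):
    """Score one (h, r) pair against n candidate tails of shape (n, d); returns shape (n,)."""
    hw = np.einsum("i,kij->kj", h, w)  # (k, d): h^T W_s, computed once for all tails
    head_term = vh @ h + b  # (k,): depends only on the head, computed once
    pre_activation = tails @ hw.T + tails @ vt.T + head_term  # (n, k)
    return activation(pre_activation) @ u  # (n,)


rng = np.random.default_rng(0)
d, k, n = 5, 3, 7
h = rng.normal(size=d)
tails = rng.normal(size=(n, d))
w = rng.normal(size=(k, d, d))
vh, vt = rng.normal(size=(2, k, d))
b, u = rng.normal(size=(2, k))

# The split V = [V_h, V_t] gives the same result as V applied to the concatenation [h; t].
v = np.concatenate([vh, vt], axis=1)  # (k, 2d)
assert np.allclose(v @ np.concatenate([h, tails[0]]), vh @ h + vt @ tails[0])

# Batched 1:n scoring matches scoring each triple separately.
batched = ntn_score_all_tails(h, tails, w, vh, vt, b, u)
looped = np.array([ntn_score(h, t, w, vh, vt, b, u) for t in tails])
assert np.allclose(batched, looped)
```

Scoring all heads for a fixed (r, t) works symmetrically, with \(\mathbf{V}_t \mathbf{t} + \mathbf{b}\) and \(\mathbf{W}\mathbf{t}\) precomputed instead.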
Initialize NTN.
- Parameters:
embedding_dim (int) – The entity embedding dimension \(d\). Usually \(d \in [50, 350]\).
num_slices (int) – The number of slices \(k\) in the relation parameters \(\mathbf{W}\), \(\mathbf{V}\), \(\mathbf{b}\), and \(\mathbf{u}\).
non_linearity (str | Module | type[Module] | None) – A non-linear activation function. Defaults to the hyperbolic tangent torch.nn.Tanh.
non_linearity_kwargs (Mapping[str, Any] | None) – If the non_linearity is passed as a class, these keyword arguments are used during its instantiation.
entity_initializer (str | Callable[[Tensor], Tensor] | None) – Entity initializer function. Defaults to torch.nn.init.uniform_().
kwargs – Remaining keyword arguments to forward to ERModel.
Note
The parameter pair (non_linearity, non_linearity_kwargs) is used for class_resolver.contrib.torch.activation_resolver.
An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
Attributes Summary
The default strategy for optimizing the model's hyper-parameters
Attributes Documentation