- class CrossEInteraction(embedding_dim=50, combination_activation=<class 'torch.nn.modules.activation.Tanh'>, combination_activation_kwargs=None, combination_dropout=0.5)
A module wrapper for the CrossE interaction function.
Instantiate the interaction module.
embedding_dim (int) – The embedding dimension.
The symbolic shapes for relation representations
func(h, r, c_r, t, bias, activation[, dropout])
Evaluate the CrossE interaction function from [zhang2019b] for the given representations.
- func(h, r, c_r, t, bias, activation, dropout=None)
Evaluate the CrossE interaction function from [zhang2019b] for the given representations.
\[\operatorname{Dropout}(\operatorname{Activation}(c_r \odot h + c_r \odot h \odot r + b))^T t\]
The representations have to be in a broadcastable shape.
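The formula can be illustrated for a single triple with a minimal plain-Python sketch. This is not the library's implementation: `crosse_score` is a hypothetical helper, the vectors are made-up lists of floats, tanh stands in for the activation, and dropout is omitted. Scoring one head and relation against several tails mimics what broadcasting does for tensors.

```python
import math


def crosse_score(h, r, c_r, t, bias):
    """Score one triple; every argument is a list of floats of equal length.

    Hypothetical helper for illustration: tanh is the activation and
    dropout is left out.
    """
    score = 0.0
    for h_i, r_i, c_i, t_i, b_i in zip(h, r, c_r, t, bias):
        # Element-wise combination: c_r * h + c_r * h * r + b
        combined = c_i * h_i + c_i * h_i * r_i + b_i
        # Dot product of the activated combination with the tail vector
        score += math.tanh(combined) * t_i
    return score


# Made-up example vectors with dim = 2
h = [1.0, 0.0]
r = [0.5, 2.0]
c_r = [1.0, 1.0]
bias = [0.0, 0.0]

single = crosse_score(h, r, c_r, [1.0, -1.0], bias)
print(single)  # about 0.905, i.e. tanh(1.5)

# One (h, r) pair scored against three candidate tails gives three scores.
tails = [[1.0, -1.0], [0.0, 2.0], [2.0, 0.0]]
scores = [crosse_score(h, r, c_r, t_j, bias) for t_j in tails]
print(len(scores))
```

In the tensor setting the loop over tails is replaced by broadcasting, so the result has one score per entry of the batch dimensions.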
The CrossE paper describes an additional sigmoid activation as part of the interaction function. Since using a log-likelihood loss can cause numerical problems (due to explicitly calling sigmoid before log), we do not apply it in our implementation but rather opt for the numerically stable variant. However, the model itself has an option, predict_with_sigmoid, which can be used to enforce application of sigmoid during inference. This can also have an impact on rank-based evaluation, since limited numerical precision can lead to exactly equal scores for multiple choices. The definition of a rank is ambiguous in such a case, and there are multiple competing variants of how to break the ties. More information on this can be found in the documentation of rank-based evaluation.
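Both effects described above can be reproduced with a small plain-Python sketch using 64-bit floats. The helper names (`sigmoid`, `log_sigmoid`, `rank_bounds`) and the raw scores are assumptions made for illustration and are not part of the library. The tie-breaking shown covers only two common conventions, here called optimistic and pessimistic.

```python
import math


def sigmoid(x):
    """Logistic function, arranged so that exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def log_sigmoid(x):
    """log(sigmoid(x)) computed directly from the raw score."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


# Training side: sigmoid underflows to exactly 0.0, so taking its log
# would fail, while the fused form stays finite.
print(sigmoid(-800.0))      # 0.0
print(log_sigmoid(-800.0))  # -800.0

# Inference side: distinct raw scores collapse to the same value.
raw = [50.0, 40.0, 3.0]
probs = [sigmoid(s) for s in raw]
print(probs[0] == probs[1])  # True, both round to 1.0


def rank_bounds(scores, index):
    """Return (optimistic, pessimistic) 1-based rank; higher score is better."""
    target = scores[index]
    better = sum(s > target for s in scores)
    ties = sum(s == target for s in scores) - 1  # do not count the target itself
    return better + 1, better + ties + 1


print(rank_bounds(raw, 1))    # (2, 2): no tie on raw scores
print(rank_bounds(probs, 1))  # (1, 2): the tie makes the rank ambiguous
```

With lower-precision floats the collapse to 1.0 happens at much smaller raw scores, so ties become more likely.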
h (FloatTensor) – shape: (*batch_dims, dim) The head representations.
r (FloatTensor) – shape: (*batch_dims, dim) The relation representations.
c_r (FloatTensor) – shape: (*batch_dims, dim) The relation-specific interaction vector.
t (FloatTensor) – shape: (*batch_dims, dim) The tail representations.
bias (FloatTensor) – shape: (dim,) The combination bias.
activation (Module) – The combination activation. Should be torch.nn.Tanh for consistency with the CrossE paper.
dropout (Optional[Dropout]) – Dropout applied after the combination.
- Return type
FloatTensor
- Returns
shape: batch_dims. The scores.