Interaction
- class Interaction(*args, **kwargs)
Bases: Module, Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], ABC
Base class for interaction functions.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
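The class combines PyTorch's Module with Generic (parameterized by the three representation types) and ABC (so that subclasses must implement forward). The following is a minimal, framework-free sketch of that pattern only. It is not PyKEEN code: InteractionSketch and TransESketch are hypothetical names, NumPy arrays stand in for tensors, and __call__ merely imitates nn.Module dispatching to forward.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import numpy as np

# Hypothetical stand-ins for HeadRepresentation, RelationRepresentation
# and TailRepresentation.
H = TypeVar("H")
R = TypeVar("R")
T = TypeVar("T")


class InteractionSketch(Generic[H, R, T], ABC):
    """Framework-free stand-in for a generic, abstract interaction base class."""

    @abstractmethod
    def forward(self, h: H, r: R, t: T) -> np.ndarray:
        """Return broadcasted scores; concrete subclasses must implement this."""
        raise NotImplementedError

    def __call__(self, h: H, r: R, t: T) -> np.ndarray:
        # Imitates torch.nn.Module, where calling the module dispatches to forward().
        return self.forward(h, r, t)


class TransESketch(InteractionSketch[np.ndarray, np.ndarray, np.ndarray]):
    """One array per slot; score = -||h + r - t||_2 over the last axis."""

    def forward(self, h: np.ndarray, r: np.ndarray, t: np.ndarray) -> np.ndarray:
        return -np.linalg.norm(h + r - t, axis=-1)


scores = TransESketch()(
    np.array([[0.0, 0.0], [1.0, 1.0]]),  # two heads
    np.array([[1.0, 0.0]]),              # one relation, broadcast over the batch
    np.array([[1.0, 0.0], [0.0, 0.0]]),  # two tails
)
```

Because forward is abstract, instantiating InteractionSketch directly raises a TypeError; only subclasses that implement forward can be used.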
Attributes Summary
dimensions – Get all the relevant dimension keys.
entity_shape – The symbolic shapes for entity representations.
head_indices – Return the entity representation indices used for the head representations.
head_shape – Return the symbolic shape for head entity representations.
Whether the interaction is defined on complex input.
relation_shape – The symbolic shapes for relation representations.
tail_indices – Return the entity representation indices used for the tail representations.
tail_shape – Return the symbolic shape for tail entity representations.
value_range – The interaction's value range (for unrestricted input).
Methods Summary
forward(h, r, t) – Compute broadcasted triple scores given broadcasted representations for heads, relations and tails.
Reset parameters the interaction function may have.
score(h, r, t[, slice_size, slice_dim]) – Compute broadcasted triple scores with optional slicing.
score_h(all_entities, r, t[, slice_size]) – Score all head entities.
score_hrt(h, r, t) – Score a batch of triples.
score_r(h, all_relations, t[, slice_size]) – Score all relations.
score_t(h, r, all_entities[, slice_size]) – Score all tail entities.
Attributes Documentation
- dimensions
Get all the relevant dimension keys.
This draws from Interaction.entity_shape and Interaction.relation_shape.
- Returns:
A set of strings representing the dimension keys.
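A minimal sketch of how such a set could be assembled, assuming (hypothetically) that each symbolic shape is a sequence of strings in which every character names one dimension key, e.g. "d" for a vector and "dk" for a d×k matrix. The function name and the example shapes are illustrative only, not PyKEEN's implementation.

```python
from itertools import chain
from typing import Sequence, Set


def dimensions(entity_shape: Sequence[str], relation_shape: Sequence[str]) -> Set[str]:
    """Collect every dimension key used by the entity and relation shapes.

    Assumption: each entry is one representation's symbolic shape and each
    character in it is a single dimension key (e.g. "dk" = d x k matrix).
    """
    # Flatten shapes -> strings -> characters, then deduplicate.
    return set(chain.from_iterable(chain(entity_shape, relation_shape)))


# Hypothetical shapes: vector entities; a vector plus a d x k matrix per relation.
keys = dimensions(entity_shape=("d",), relation_shape=("d", "dk"))
```

Here keys contains "d" and "k": the repeated "d" is reported once, which is why a set is a natural return type.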
- head_indices
Return the entity representation indices used for the head representations.
- head_shape
Return the symbolic shape for head entity representations.
- tail_indices
Return the entity representation indices used for the tail representations.
- tail_shape
Return the symbolic shape for tail entity representations.
- value_range: ClassVar[ValueRange] = ValueRange(lower=None, lower_inclusive=False, upper=None, upper_inclusive=False)
The interaction’s value range (for unrestricted input).
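The default shown above leaves both bounds as None, i.e. the scores are unbounded. To make the four fields concrete, here is a hypothetical mirror of them with a membership test; ValueRangeSketch and its __contains__ method are illustrative assumptions, not PyKEEN's ValueRange API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ValueRangeSketch:
    """Hypothetical mirror of the ValueRange fields shown above."""

    lower: Optional[float] = None
    lower_inclusive: bool = False
    upper: Optional[float] = None
    upper_inclusive: bool = False

    def __contains__(self, x: float) -> bool:
        # A bound of None means the range is unbounded on that side.
        if self.lower is not None:
            if x < self.lower or (x == self.lower and not self.lower_inclusive):
                return False
        if self.upper is not None:
            if x > self.upper or (x == self.upper and not self.upper_inclusive):
                return False
        return True


unrestricted = ValueRangeSketch()  # same field values as the default above
non_positive = ValueRangeSketch(upper=0.0, upper_inclusive=True)  # e.g. negative distances
```

For example, 0.0 lies in non_positive because upper_inclusive is True, while 0.5 does not.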
Methods Documentation
- abstract forward(h: HeadRepresentation, r: RelationRepresentation, t: TailRepresentation) → Tensor
Compute broadcasted triple scores given broadcasted representations for heads, relations and tails.
In general, each interaction function (class) expects a specific format for each of the head, relation and tail representations. This format comprises the number and the shape of the representations.
Many simple interaction functions, such as TransEInteraction, operate on a single representation. However, there are also interactions such as TransDInteraction, which requires two representations for each slot, or PairREInteraction, which requires two relation representations but only a single representation each for the head and tail entity.
Each individual representation has a shape. This can be a simple \(d\)-dimensional vector, but it can also be a matrix or even a higher-order tensor.
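To illustrate what "two relation representations" means, here is a NumPy sketch of a PairRE-style score in which the relation slot holds a tuple (r_h, r_t) while head and tail each hold a single vector. The L1 norm and the absence of any entity normalization are simplifying assumptions, and pairre_style_score is a hypothetical name; this is not PyKEEN's PairREInteraction code.

```python
from typing import Tuple

import numpy as np


def pairre_style_score(
    h: np.ndarray, r: Tuple[np.ndarray, np.ndarray], t: np.ndarray
) -> np.ndarray:
    """Score -||h * r_h - t * r_t||_1 over the last axis.

    h and t are single representations; r is a pair (r_h, r_t) of relation
    representations, one applied to the head and one to the tail.
    """
    r_h, r_t = r
    return -np.abs(h * r_h - t * r_t).sum(axis=-1)


h = np.array([1.0, 1.0])
t = np.array([0.0, 2.0])
r = (np.array([2.0, 1.0]), np.array([1.0, 0.5]))
score = pairre_style_score(h, r, t)
```

Here h * r_h = [2, 1] and t * r_t = [0, 1], so the score is -2.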
This method supports general batched calculation, i.e., each of the representations can have preceding batch dimensions. These batch dimensions do not need to be exactly the same, but they must be broadcastable. A good explanation of broadcasting rules can be found in NumPy’s documentation.
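A short NumPy sketch of this broadcasting behavior, using a TransE-style score (negative L2 distance) as an assumed example interaction: three heads, one shared relation and five tails have batch shapes (3, 1), (1, 1) and (1, 5), which broadcast to a (3, 5) grid of scores without copying any representation. transe_scores is a hypothetical helper, not part of PyKEEN.

```python
import numpy as np


def transe_scores(h, r, t):
    # The last axis is the feature axis and is reduced; all leading axes are
    # batch axes that only need to be broadcast-compatible, not identical.
    return -np.linalg.norm(h + r - t, axis=-1)


rng = np.random.default_rng(0)
d = 4
h = rng.normal(size=(3, 1, d))  # 3 heads
r = rng.normal(size=(1, 1, d))  # 1 relation, shared by every pair
t = rng.normal(size=(1, 5, d))  # 5 tails
scores = transe_scores(h, r, t)  # batch dims (3, 1), (1, 1), (1, 5) -> (3, 5)
```

Entry scores[i, j] is the score of head i and tail j under the shared relation.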
See also
Representations for an overview of the different ways to obtain individual representations.
- Parameters:
h (HeadRepresentation) – shape: (*batch_dims, *dims) The head representations.
r (RelationRepresentation) – shape: (*batch_dims, *dims) The relation representations.
t (TailRepresentation) – shape: (*batch_dims, *dims) The tail representations.
- Returns:
shape: batch_dims The scores.
- Return type:
Tensor
- score(h: HeadRepresentation, r: RelationRepresentation, t: TailRepresentation, slice_size: int | None = None, slice_dim: int = 1) → Tensor
Compute broadcasted triple scores with optional slicing.
Note
At most one of the slice sizes may be set to a value other than None.
Todo
We could change this to support slicing along multiple dimensions, if necessary.
- Parameters:
h (HeadRepresentation) – shape: (*batch_dims, *dims) The head representations.
r (RelationRepresentation) – shape: (*batch_dims, *dims) The relation representations.
t (TailRepresentation) – shape: (*batch_dims, *dims) The tail representations.
slice_size (int | None) – The slice size.
slice_dim (int) – The dimension along which to slice; one of {0, …, len(batch_dims)}.
- Returns:
shape: batch_dims The scores.
- Return type:
Tensor
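Slicing is usually a way to bound the size of intermediate arrays by evaluating the scores chunk by chunk. The sketch below is an assumed approach, not PyKEEN's implementation: it splits every input that is not broadcast (size 1) along slice_dim, scores each chunk with a given forward function, and concatenates the results. It assumes all inputs have the same number of axes (batch axes plus one trailing feature axis); score_sliced, _take and transe_forward are hypothetical names.

```python
from typing import Callable, Optional

import numpy as np


def _take(x: np.ndarray, start: int, stop: int, dim: int) -> np.ndarray:
    # Inputs of size 1 along `dim` are broadcast anyway, so reuse them whole.
    if x.shape[dim] == 1:
        return x
    index = [slice(None)] * x.ndim
    index[dim] = slice(start, stop)
    return x[tuple(index)]


def score_sliced(
    forward: Callable[..., np.ndarray],
    h: np.ndarray,
    r: np.ndarray,
    t: np.ndarray,
    slice_size: Optional[int] = None,
    slice_dim: int = 1,
) -> np.ndarray:
    """Evaluate forward(h, r, t) chunk by chunk along batch axis `slice_dim`."""
    if slice_size is None:
        return forward(h, r, t)
    n = max(x.shape[slice_dim] for x in (h, r, t))
    chunks = [
        forward(*(_take(x, start, start + slice_size, slice_dim) for x in (h, r, t)))
        for start in range(0, n, slice_size)
    ]
    # Each chunk keeps all batch axes, so concatenating along slice_dim
    # reassembles the full score array.
    return np.concatenate(chunks, axis=slice_dim)


def transe_forward(h, r, t):
    return -np.linalg.norm(h + r - t, axis=-1)


rng = np.random.default_rng(1)
h = rng.normal(size=(3, 1, 8))
r = rng.normal(size=(3, 1, 8))
t = rng.normal(size=(1, 10, 8))
full = score_sliced(transe_forward, h, r, t)
sliced = score_sliced(transe_forward, h, r, t, slice_size=4, slice_dim=1)
```

Both calls return a (3, 10) array with the same values; the sliced version only ever materializes at most 4 tails per chunk.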
- score_h(all_entities: HeadRepresentation, r: RelationRepresentation, t: TailRepresentation, slice_size: int | None = None) → Tensor
Score all head entities.
- Parameters:
all_entities (HeadRepresentation) – shape: (num_entities, d_e) The head representations.
r (RelationRepresentation) – shape: (batch_size, d_r) The relation representations.
t (TailRepresentation) – shape: (batch_size, d_e) The tail representations.
slice_size (int | None) – The slice size.
- Returns:
shape: (batch_size, num_entities) The scores.
- Return type:
Tensor
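One way to obtain the documented (batch_size, num_entities) output from a broadcasting forward is to insert singleton axes so that entities and batch occupy different positions. The sketch below assumes a TransE-style interaction (so d_r equals d_e); score_h_sketch and transe_forward are hypothetical helpers, not PyKEEN's code.

```python
import numpy as np


def transe_forward(h, r, t):
    # Reduce over the feature axis; leading axes broadcast.
    return -np.linalg.norm(h + r - t, axis=-1)


def score_h_sketch(all_entities, r, t):
    # all_entities: (num_entities, d) -> (1, num_entities, d)
    # r, t:         (batch_size, d)   -> (batch_size, 1, d)
    # Broadcasting then yields scores of shape (batch_size, num_entities).
    return transe_forward(all_entities[None, :, :], r[:, None, :], t[:, None, :])


rng = np.random.default_rng(2)
all_entities = rng.normal(size=(6, 4))
r = rng.normal(size=(2, 4))
t = rng.normal(size=(2, 4))
scores = score_h_sketch(all_entities, r, t)
```

Entry scores[b, e] scores entity e as the head of the b-th (r, t) pair.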
- score_hrt(h: HeadRepresentation, r: RelationRepresentation, t: TailRepresentation) → Tensor
Score a batch of triples.
- Parameters:
h (HeadRepresentation) – shape: (batch_size, d_e) The head representations.
r (RelationRepresentation) – shape: (batch_size, d_r) The relation representations.
t (TailRepresentation) – shape: (batch_size, d_e) The tail representations.
- Returns:
shape: (batch_size, 1) The scores.
- Return type:
Tensor
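In score_hrt the i-th rows of h, r and t form one triple, so no broadcasting across the batch is needed; the documented (batch_size, 1) shape just adds a trailing axis. A sketch under the same TransE-style assumption, with hypothetical helper names:

```python
import numpy as np


def transe_forward(h, r, t):
    return -np.linalg.norm(h + r - t, axis=-1)


def score_hrt_sketch(h, r, t):
    # h, r, t: (batch_size, d); row i is the i-th triple.
    # forward gives (batch_size,); the extra axis makes it (batch_size, 1).
    return transe_forward(h, r, t)[:, None]


h = np.array([[0.0, 0.0], [3.0, 0.0]])
r = np.array([[1.0, 0.0], [0.0, 4.0]])
t = np.array([[1.0, 0.0], [0.0, 0.0]])
scores = score_hrt_sketch(h, r, t)
```

The first triple satisfies h + r = t exactly and scores 0; the second has h + r - t = [3, 4] and scores -5.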
- score_r(h: HeadRepresentation, all_relations: RelationRepresentation, t: TailRepresentation, slice_size: int | None = None) → Tensor
Score all relations.
- Parameters:
h (HeadRepresentation) – shape: (batch_size, d_e) The head representations.
all_relations (RelationRepresentation) – shape: (num_relations, d_r) The relation representations.
t (TailRepresentation) – shape: (batch_size, d_e) The tail representations.
slice_size (int | None) – The slice size.
- Returns:
shape: (batch_size, num_relations) The scores.
- Return type:
Tensor
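For score_r the relation axis takes the place of the entity axis. The sketch below, again assuming a TransE-style interaction where d_r equals d_e and using hypothetical helper names, checks that the broadcasted result matches an explicit Python loop over all relations.

```python
import numpy as np


def transe_forward(h, r, t):
    return -np.linalg.norm(h + r - t, axis=-1)


def score_r_sketch(h, all_relations, t):
    # h, t:          (batch_size, d)    -> (batch_size, 1, d)
    # all_relations: (num_relations, d) -> (1, num_relations, d)
    return transe_forward(h[:, None, :], all_relations[None, :, :], t[:, None, :])


rng = np.random.default_rng(3)
h = rng.normal(size=(4, 5))
t = rng.normal(size=(4, 5))
all_relations = rng.normal(size=(7, 5))
scores = score_r_sketch(h, all_relations, t)  # (batch_size, num_relations)

# Reference: score one relation at a time and stack the columns.
looped = np.stack([transe_forward(h, rel[None, :], t) for rel in all_relations], axis=1)
```

The broadcasted version avoids the Python loop but materializes a (batch_size, num_relations, d) intermediate, which is where a slice size can help.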
- score_t(h: HeadRepresentation, r: RelationRepresentation, all_entities: TailRepresentation, slice_size: int | None = None) → Tensor
Score all tail entities.
- Parameters:
h (HeadRepresentation) – shape: (batch_size, d_e) The head representations.
r (RelationRepresentation) – shape: (batch_size, d_r) The relation representations.
all_entities (TailRepresentation) – shape: (num_entities, d_e) The tail representations.
slice_size (int | None) – The slice size.
- Returns:
shape: (batch_size, num_entities) The scores.
- Return type:
Tensor
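A common use of scoring all tail entities is ranking candidate tails for each (h, r) pair. This sketch assumes a TransE-style interaction returning negative distances, so a higher score means a more plausible triple; score_t_sketch and transe_forward are hypothetical helpers, and the toy embeddings are made up for illustration.

```python
import numpy as np


def transe_forward(h, r, t):
    return -np.linalg.norm(h + r - t, axis=-1)


def score_t_sketch(h, r, all_entities):
    # h, r:         (batch_size, d)   -> (batch_size, 1, d)
    # all_entities: (num_entities, d) -> (1, num_entities, d)
    return transe_forward(h[:, None, :], r[:, None, :], all_entities[None, :, :])


all_entities = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
h = np.array([[0.0, 0.0]])
r = np.array([[0.0, 1.0]])
scores = score_t_sketch(h, r, all_entities)  # (1, 4)

# Sort candidate tails from highest to lowest score for each (h, r) pair.
ranking = np.argsort(-scores, axis=1)
```

Since h + r = [0, 1], entity 2 is ranked first, followed by entities 0, 1 and 3.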