Interaction
- class Interaction(*args, **kwargs)
Bases: Module, Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], ABC
Base class for interaction functions.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
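The abstract-base-class pattern can be sketched roughly as follows. This is a hypothetical, simplified stand-in, not the library's own code. NumPy arrays replace torch tensors, `MiniInteraction` and `DistMultLikeInteraction` are illustrative names, and the DistMult-style trilinear product is only one example of what a subclass might compute in `forward`.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import numpy as np

# Hypothetical type variables mirroring the three generic parameters of the base class.
HeadRepresentation = TypeVar("HeadRepresentation")
RelationRepresentation = TypeVar("RelationRepresentation")
TailRepresentation = TypeVar("TailRepresentation")


class MiniInteraction(Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], ABC):
    """Hypothetical, simplified stand-in for an interaction base class."""

    @abstractmethod
    def forward(self, h: HeadRepresentation, r: RelationRepresentation, t: TailRepresentation) -> np.ndarray:
        """Return scores of shape batch_dims for broadcastable h, r, t."""
        raise NotImplementedError


class DistMultLikeInteraction(MiniInteraction[np.ndarray, np.ndarray, np.ndarray]):
    """Example subclass: trilinear product reduced over the last (embedding) axis."""

    def forward(self, h: np.ndarray, r: np.ndarray, t: np.ndarray) -> np.ndarray:
        return np.sum(h * r * t, axis=-1)


interaction = DistMultLikeInteraction()
h = np.array([[1.0, 2.0]])
r = np.array([[1.0, 1.0]])
t = np.array([[3.0, 4.0]])
print(interaction.forward(h, r, t))  # [11.]
```

Because `forward` is abstract, the base class itself cannot be instantiated; only subclasses that implement it can.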
Attributes Summary
entity_shape – The symbolic shapes for entity representations.
Whether the interaction is defined on complex input.
relation_shape – The symbolic shapes for relation representations.
tail_entity_shape – Return the symbolic shape for tail entity representations.
value_range – The interaction's value range (for unrestricted input).
Methods Summary
forward(h, r, t) – Compute broadcasted triple scores given broadcasted representations for heads, relations, and tails.
Return all entity shapes (head & tail).
get_dimensions() – Get all of the relevant dimension keys.
Return the entity representation indices used for the head representations.
Reset parameters the interaction function may have.
score(h, r, t[, slice_size, slice_dim]) – Compute broadcasted triple scores with optional slicing.
score_h(all_entities, r, t[, slice_size]) – Score all head entities.
score_hrt(h, r, t) – Score a batch of triples.
score_r(h, all_relations, t[, slice_size]) – Score all relations.
score_t(h, r, all_entities[, slice_size]) – Score all tail entities.
Return the entity representation indices used for the tail representations.
Attributes Documentation
- tail_entity_shape
Return the symbolic shape for tail entity representations.
- value_range: ClassVar[ValueRange] = ValueRange(lower=None, lower_inclusive=False, upper=None, upper_inclusive=False)
the interaction’s value range (for unrestricted input)
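To make the four fields of the default value concrete, here is a minimal sketch of a range type with the same field names. This is an assumption-laden stand-in, not the library's `ValueRange` class. `None` is read as "unbounded", and the `contains` helper is hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ValueRange:
    """Simplified stand-in using the four fields shown above; None means unbounded."""

    lower: Optional[float] = None
    lower_inclusive: bool = False
    upper: Optional[float] = None
    upper_inclusive: bool = False

    def contains(self, x: float) -> bool:
        """Hypothetical helper: check whether x lies inside the range."""
        if self.lower is not None:
            if x < self.lower or (x == self.lower and not self.lower_inclusive):
                return False
        if self.upper is not None:
            if x > self.upper or (x == self.upper and not self.upper_inclusive):
                return False
        return True


unrestricted = ValueRange()  # same field values as the default shown for value_range
non_positive = ValueRange(upper=0.0, upper_inclusive=True)
print(unrestricted.contains(-1e9), non_positive.contains(0.0), non_positive.contains(0.5))  # True True False
```

With both bounds set to `None`, every real score is admissible, which matches the default above.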
Methods Documentation
- abstract forward(h, r, t)
Compute broadcasted triple scores given broadcasted representations for heads, relations, and tails.
- Parameters:
h (~HeadRepresentation) – shape: (*batch_dims, *dims) The head representations.
r (~RelationRepresentation) – shape: (*batch_dims, *dims) The relation representations.
t (~TailRepresentation) – shape: (*batch_dims, *dims) The tail representations.
- Return type:
FloatTensor
- Returns:
shape: batch_dims The scores.
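The broadcasting contract can be illustrated with a small NumPy sketch. Here the output `batch_dims` is the broadcast of the inputs' batch dimensions, and only the trailing representation axis is reduced. The DistMult-style product is an illustrative assumption, since the base class leaves `forward` abstract.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_entities, dim = 2, 5, 3

# Singleton batch axes are expanded against the other inputs.
h = rng.normal(size=(batch_size, 1, dim))    # batch_dims = (2, 1), dims = (3,)
r = rng.normal(size=(batch_size, 1, dim))    # batch_dims = (2, 1)
t = rng.normal(size=(1, num_entities, dim))  # batch_dims = (1, 5)

# DistMult-style stand-in for forward(): reduce over the representation axis only.
scores = np.sum(h * r * t, axis=-1)
print(scores.shape)  # (2, 5)
```

Passing one tail axis of size `num_entities` against singleton head and relation axes is what lets a single call score each batch element against many candidates.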
- classmethod get_dimensions()
Get all of the relevant dimension keys.
This draws from Interaction.entity_shape, Interaction.relation_shape, and, in the case of ConvEInteraction, the Interaction.tail_entity_shape.
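The collection step can be sketched as a set union over the symbolic shape tuples. This is a free-function approximation under stated assumptions, not the actual classmethod. The shape keys `"d"`, `"k"` and `"e"` are hypothetical, and the tail shape is assumed to be `None` unless an interaction sets it separately.

```python
import itertools as itt
from typing import Optional, Sequence, Set


def get_dimensions(
    entity_shape: Sequence[str],
    relation_shape: Sequence[str],
    tail_entity_shape: Optional[Sequence[str]] = None,
) -> Set[str]:
    """Collect every symbolic dimension key used by an interaction's shapes."""
    # The tail shape only contributes when it is set separately (as in the ConvE case).
    return set(itt.chain(entity_shape, tail_entity_shape or (), relation_shape))


# Hypothetical symbolic shapes.
print(sorted(get_dimensions(entity_shape=("d",), relation_shape=("d", "k"))))  # ['d', 'k']
print(sorted(get_dimensions(("d",), ("d",), tail_entity_shape=("d", "e"))))  # ['d', 'e']
```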
- score(h, r, t, slice_size=None, slice_dim=1)
Compute broadcasted triple scores with optional slicing.
Note
At most one of the slice sizes may be set (i.e., not None).
# TODO: we could change that to slicing along multiple dimensions, if necessary
- Parameters:
h (~HeadRepresentation) – shape: (*batch_dims, *dims) The head representations.
r (~RelationRepresentation) – shape: (*batch_dims, *dims) The relation representations.
t (~TailRepresentation) – shape: (*batch_dims, *dims) The tail representations.
slice_dim (int) – The dimension along which to slice. From {0, …, len(batch_dims)}
- Return type:
FloatTensor
- Returns:
shape: batch_dims The scores.
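Slicing trades peak memory for extra `forward` calls. The following NumPy sketch shows one way it could work and is not the library's implementation. It assumes a DistMult-style `forward`, and it assumes that inputs whose extent along `slice_dim` is 1 (broadcast axes) are passed through unsliced while the others are cut into chunks of `slice_size`. The chunk results are then concatenated along the same axis.

```python
from typing import Optional

import numpy as np


def forward(h: np.ndarray, r: np.ndarray, t: np.ndarray) -> np.ndarray:
    """DistMult-style stand-in for an interaction's forward()."""
    return np.sum(h * r * t, axis=-1)


def score(h, r, t, slice_size: Optional[int] = None, slice_dim: int = 1) -> np.ndarray:
    """Evaluate forward() in chunks of slice_size along batch axis slice_dim."""
    if slice_size is None:
        return forward(h, r, t)
    n = max(x.shape[slice_dim] for x in (h, r, t))
    prefix = (slice(None),) * slice_dim  # keep all axes before slice_dim intact
    chunks = []
    for start in range(0, n, slice_size):
        window = prefix + (slice(start, start + slice_size),)
        # Broadcast (size-1) inputs are reused for every chunk; the others are sliced.
        parts = [x if x.shape[slice_dim] == 1 else x[window] for x in (h, r, t)]
        chunks.append(forward(*parts))
    return np.concatenate(chunks, axis=slice_dim)


rng = np.random.default_rng(0)
h = rng.normal(size=(2, 1, 3))
r = rng.normal(size=(2, 1, 3))
t = rng.normal(size=(1, 5, 3))
full = score(h, r, t)
sliced = score(h, r, t, slice_size=2, slice_dim=1)
print(full.shape, sliced.shape, np.allclose(full, sliced))  # (2, 5) (2, 5) True
```

Since `forward` only reduces the trailing axis, the batch axes of the output line up with those of the inputs. That alignment is what makes concatenating along `slice_dim` valid.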
- score_h(all_entities, r, t, slice_size=None)
Score all head entities.
- Parameters:
- Return type:
FloatTensor
- Returns:
shape: (batch_size, num_entities) The scores.
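Scoring all heads amounts to inserting singleton axes so that every entity is paired with every (r, t) in the batch. This is a sketch under two assumptions, because the parameter list above is empty. First, `all_entities` is taken to have shape (num_entities, d) and `r`, `t` shape (batch_size, d). Second, the DistMult-style product stands in for `forward`.

```python
import numpy as np


def score_h(all_entities: np.ndarray, r: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Score every entity as the head of each (r, t) pair in the batch."""
    h = all_entities[np.newaxis, :, :]  # (1, num_entities, d)
    r = r[:, np.newaxis, :]  # (batch_size, 1, d)
    t = t[:, np.newaxis, :]  # (batch_size, 1, d)
    # DistMult-style stand-in for forward(); broadcasts to (batch_size, num_entities).
    return np.sum(h * r * t, axis=-1)


rng = np.random.default_rng(1)
all_entities = rng.normal(size=(6, 3))  # num_entities = 6, d = 3
r = rng.normal(size=(4, 3))  # batch_size = 4
t = rng.normal(size=(4, 3))
print(score_h(all_entities, r, t).shape)  # (4, 6)
```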
- score_hrt(h, r, t)
Score a batch of triples.
- Parameters:
h (~HeadRepresentation) – shape: (batch_size, d_e) The head representations.
r (~RelationRepresentation) – shape: (batch_size, d_r) The relation representations.
t (~TailRepresentation) – shape: (batch_size, d_e) The tail representations.
- Return type:
FloatTensor
- Returns:
shape: (batch_size, 1) The scores.
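For aligned triples, row i of `h`, `r` and `t` forms one triple, and the trailing axis of size 1 in the result can be kept with `keepdims=True`. The sketch below uses a DistMult-style product as an illustrative assumption, which also implies d_e = d_r. Other interactions may use different entity and relation sizes.

```python
import numpy as np


def score_hrt(h: np.ndarray, r: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Score aligned triples: row i of h, r and t forms one triple."""
    # DistMult-style stand-in for forward(); keepdims=True keeps the trailing axis of size 1.
    return np.sum(h * r * t, axis=-1, keepdims=True)


h = np.array([[1.0, 0.0], [0.5, 2.0]])
r = np.ones((2, 2))
t = np.array([[2.0, 9.0], [2.0, 1.0]])
scores = score_hrt(h, r, t)
print(scores.shape, scores.ravel().tolist())  # (2, 1) [2.0, 3.0]
```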
- score_r(h, all_relations, t, slice_size=None)
Score all relations.
- Parameters:
- Return type:
FloatTensor
- Returns:
shape: (batch_size, num_relations) The scores.
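Scoring all relations yields one score per relation for each (h, t) pair. Below is a compact NumPy sketch using `einsum`. It assumes `all_relations` has shape (num_relations, d), since the parameter list above is empty, and it uses a DistMult-style product as a stand-in for `forward`.

```python
import numpy as np


def score_r(h: np.ndarray, all_relations: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Score every relation for each (h, t) pair in the batch."""
    # Subscripts: b = batch, n = relation index, d = embedding axis (summed out).
    return np.einsum("bd,nd,bd->bn", h, all_relations, t)


rng = np.random.default_rng(2)
h = rng.normal(size=(3, 4))  # batch_size = 3, d = 4
all_relations = rng.normal(size=(7, 4))  # num_relations = 7
t = rng.normal(size=(3, 4))
print(score_r(h, all_relations, t).shape)  # (3, 7)
```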