BoxEInteraction
- class BoxEInteraction(tanh_map=True, p=2, power_norm=False)
Bases: NormBasedInteraction[Tuple[FloatTensor, FloatTensor], Tuple[FloatTensor, FloatTensor, FloatTensor, FloatTensor, FloatTensor, FloatTensor], Tuple[FloatTensor, FloatTensor]]
The BoxE interaction from [abboud2020].
Entities are represented by two \(d\)-dimensional vectors: a base position and a translational bump. The bump translates all entities co-occurring in a fact with this entity from their base positions to their final embeddings, an operation called “bumping”.
Relations are represented as a fixed number of hyper-rectangles corresponding to the relation’s arity. Since we are only considering single-hop link prediction here, the arity is always two, i.e., one box for the head position and another one for the tail position. There are different ways to parametrize a hyper-rectangle; the most common is probably to give the coordinates of two opposing vertices. BoxE suggests a different parametrization:
each box has a base position given by its center
each box has an extent in each dimension. This extent is further factored into
a scalar global scaling factor
a normalized extent in each dimension, i.e., the extents are normalized to have a geometric mean of one (a base shape with a volume of 1)
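Instantiate the interaction module.
- Parameters:
tanh_map (bool) – whether to apply the tanh mapping
p (int) – the order of the norm to apply
power_norm (bool) – whether to use the p-th power of the p-norm instead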
Attributes Summary
entity_shape: The symbolic shapes for entity representations
relation_shape: The symbolic shapes for relation representations
Methods Summary
boxe_kg_arity_position_score(entity_pos, ...): Perform the BoxE computation at a single arity position.
compute_box(base, delta, size): Compute the lower and upper corners of a resulting box.
func(h_pos, h_bump, rh_base, rh_delta, ...): Evaluate the BoxE interaction function from [abboud2020].
point_to_box_distance(points, box_lows, ...): Compute the point to box distance function proposed by [abboud2020] in an element-wise fashion.
product_normalize(x[, dim]): Normalize a tensor along a given dimension so that the geometric mean is 1.0.
Attributes Documentation
- relation_shape: Sequence[str] = ('d', 'd', 's', 'd', 'd', 's')
The symbolic shapes for relation representations
Methods Documentation
- classmethod boxe_kg_arity_position_score(entity_pos, other_entity_bump, relation_box, tanh_map, p, power_norm)
Perform the BoxE computation at a single arity position.
Note
This computation is parallelizable across all positions.
Note
entity_pos, other_entity_bump, and both corners of relation_box (lower and upper) must have broadcastable shapes.
- Parameters:
entity_pos (FloatTensor) – shape: (*s_p, d)
This is the base entity position of the entity appearing in the target position. For example, for a fact \(r(h, t)\) and the head arity position, entity_pos is the base position of \(h\).
other_entity_bump (FloatTensor) – shape: (*s_b, d)
This is the bump of the entity at the other position in the fact. For example, given a fact \(r(h, t)\) and the head arity position, other_entity_bump is the bump of \(t\).
relation_box (Tuple[FloatTensor, FloatTensor]) – shape: (*s_r, d)
The lower/upper corner of the relation box at the target arity position.
tanh_map (bool) – whether to apply the tanh map regularizer
p (int) – the norm order to apply across dimensions to compute the overall position score
power_norm (bool) – whether to use the powered norm instead
- Return type:
FloatTensor
- Returns:
shape: *s
Arity-position score for the entity relative to the target relation box. Larger is better. The shape is the broadcasted shape from position, bump and box, where the last dimension has been removed.
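The description above can be illustrated with a minimal PyTorch sketch. It is not the library's implementation: the function names `_distance_sketch` and `arity_position_score_sketch` are hypothetical, applying tanh to both the bumped point and the box corners is an assumption about what the tanh map does, and the power_norm option is left out. The distance helper follows the piecewise formula documented under point_to_box_distance.

```python
import torch


def _distance_sketch(points, lows, highs):
    # Element-wise point-to-box distance (see point_to_box_distance).
    widths = highs - lows
    centers = 0.5 * (highs + lows)
    widths_p1 = widths + 1
    offset = (points - centers).abs()
    inside = (lows <= points) & (points <= highs)
    outside = offset * widths_p1 - 0.5 * widths * (widths_p1 - 1 / widths_p1)
    return torch.where(inside, offset / widths_p1, outside)


def arity_position_score_sketch(entity_pos, other_entity_bump, relation_box, tanh_map=True, p=2):
    box_low, box_high = relation_box
    # "Bumping": translate the entity by the bump of the other entity in the fact.
    bumped = entity_pos + other_entity_bump
    if tanh_map:
        # Assumption: the tanh map is applied to the point and to both box corners.
        bumped = torch.tanh(bumped)
        box_low, box_high = torch.tanh(box_low), torch.tanh(box_high)
    dist = _distance_sketch(bumped, box_low, box_high)
    # Negate the p-norm over the last dimension so that larger is better.
    return -dist.norm(p=p, dim=-1)


# One box of dimension d=4, scored against three bumped entities.
box = (torch.full((4,), -1.0), torch.full((4,), 1.0))
entity_pos = torch.zeros(3, 4)
bump = torch.tensor([[0.0] * 4, [0.5] * 4, [5.0] * 4])
scores = arity_position_score_sketch(entity_pos, bump, box)
print(scores.shape, scores)
```

In this example the first entity sits at the box center and gets the best score, the second lies inside the box but off-center, and the third lies outside the box and scores lowest.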
- classmethod compute_box(base, delta, size)
Compute the lower and upper corners of a resulting box.
- Parameters:
base (FloatTensor) – shape: (*, d)
the base position (box center) of the input relation embeddings
delta (FloatTensor) – shape: (*, d)
the base shape of the input relation embeddings
size (FloatTensor) – shape: (*, d)
the size scalar vectors of the input relation embeddings
- Return type:
Tuple[FloatTensor, FloatTensor]
- Returns:
shape: (*, d)
the lower and upper bounds of the box whose embeddings are provided as input, each of shape (*, d).
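As a rough illustration of how a center-based parametrization maps to box corners, the following sketch converts base, delta, and size into lower and upper bounds. The name `compute_box_sketch` is hypothetical, and the sketch assumes that `delta` is already normalized and `size` is already positive; the library method may apply such normalization steps itself.

```python
import torch


def compute_box_sketch(base, delta, size):
    # Assumption: `delta` is already normalized and `size` is already positive.
    extent = size * delta  # per-dimension extent of the box
    first = base - 0.5 * extent
    second = base + 0.5 * extent
    # Taking min/max keeps low <= high even if some delta entries are negative.
    return torch.minimum(first, second), torch.maximum(first, second)


base = torch.tensor([0.0, 1.0])
delta = torch.tensor([0.5, -2.0])
size = torch.tensor([2.0])  # a scalar size, broadcast over both dimensions
low, high = compute_box_sketch(base, delta, size)
print(low, high)
```

The box stays centered at `base` in every dimension; only its extent changes with `delta` and `size`.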
- static func(h_pos, h_bump, rh_base, rh_delta, rh_size, rt_base, rt_delta, rt_size, t_pos, t_bump, tanh_map=True, p=2, power_norm=False)
Evaluate the BoxE interaction function from [abboud2020].
- Parameters:
h_pos (FloatTensor) – shape: (*batch_dims, d) the head entity position
h_bump (FloatTensor) – shape: (*batch_dims, d) the head entity bump
rh_base (FloatTensor) – shape: (*batch_dims, d) the relation-specific head box base position
rh_delta (FloatTensor) – shape: (*batch_dims, d) the relation-specific head box base shape (normalized to have a volume of 1)
rh_size (FloatTensor) – shape: (*batch_dims, 1) the relation-specific head box size (a scalar)
rt_base (FloatTensor) – shape: (*batch_dims, d) the relation-specific tail box base position
rt_delta (FloatTensor) – shape: (*batch_dims, d) the relation-specific tail box base shape (normalized to have a volume of 1)
rt_size (FloatTensor) – shape: (*batch_dims, 1) the relation-specific tail box size (a scalar)
t_pos (FloatTensor) – shape: (*batch_dims, d) the tail entity position
t_bump (FloatTensor) – shape: (*batch_dims, d) the tail entity bump
tanh_map (bool) – whether to apply the tanh mapping
p (int) – the order of the norm to apply
power_norm (bool) – whether to use the p-th power of the p-norm instead
- Return type:
FloatTensor
- Returns:
shape: batch_dims The scores.
- static point_to_box_distance(points, box_lows, box_highs)
Compute the point to box distance function proposed by [abboud2020] in an element-wise fashion.
- Parameters:
points (FloatTensor) – shape: (*, d)
the positions of the points being scored against boxes
box_lows (FloatTensor) – shape: (*, d)
the lower corners of the boxes
box_highs (FloatTensor) – shape: (*, d)
the upper corners of the boxes
- Return type:
FloatTensor
- Returns:
Element-wise distance function scores as per the definition below.
Given points \(p\), box_lows \(l\), and box_highs \(h\), the following quantities are defined:
Width \(w\) is the difference between the upper and lower box bound: \(w = h - l\)
Box centers \(c\) are the mean of the box bounds: \(c = (h + l) / 2\)
Finally, the point to box distance \(dist(p,l,h)\) is defined as the following piecewise function:
\[\begin{split}dist(p,l,h) = \begin{cases} |p-c|/(w+1) & l \leq p \leq h \\ |p-c| \cdot (w+1) - 0.5 \cdot w \cdot ((w+1) - 1/(w+1)) & \text{otherwise} \\ \end{cases}\end{split}\]
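The piecewise definition translates fairly directly into tensor operations. The sketch below is an illustrative re-implementation of the formula, not the library's code; the name `point_to_box_distance_sketch` is hypothetical.

```python
import torch


def point_to_box_distance_sketch(points, box_lows, box_highs):
    widths = box_highs - box_lows           # w = h - l
    centers = 0.5 * (box_highs + box_lows)  # c = (h + l) / 2
    widths_p1 = widths + 1                  # (w + 1)
    offset = (points - centers).abs()       # |p - c|
    inside = (box_lows <= points) & (points <= box_highs)
    dist_inside = offset / widths_p1
    dist_outside = offset * widths_p1 - 0.5 * widths * (widths_p1 - 1 / widths_p1)
    # Select the matching branch of the piecewise function per element.
    return torch.where(inside, dist_inside, dist_outside)


lows = torch.tensor([0.0, 0.0])
highs = torch.tensor([2.0, 2.0])
points = torch.tensor([1.5, 3.0])  # first coordinate inside, second outside
result = point_to_box_distance_sketch(points, lows, highs)
print(result)
```

With \(w = 2\) and \(c = 1\), the first coordinate gives \(0.5 / 3 = 1/6\) and the second gives \(2 \cdot 3 - 0.5 \cdot 2 \cdot (3 - 1/3) = 10/3\). The two branches agree on the box boundary (both give \(1/3\) at \(p = 2\)), so the distance is continuous there; it grows slowly inside the box and quickly outside.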
- static product_normalize(x, dim=-1)
Normalize a tensor along a given dimension so that the geometric mean is 1.0.
- Parameters:
x (FloatTensor) – shape: s An input tensor
dim (int) – the dimension along which to normalize the tensor
- Return type:
FloatTensor
- Returns:
shape: s An output tensor where the given dimension is normalized to have a geometric mean of 1.0.
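One way to obtain a geometric mean of 1.0 is to divide by the geometric mean of the absolute values, computed in log-space. The sketch below shows that idea; the name `product_normalize_sketch` and the `eps` clamp (added to avoid taking the logarithm of zero) are assumptions, not part of the documented API.

```python
import torch


def product_normalize_sketch(x, dim=-1, eps=1e-12):
    # Geometric mean of |x| along `dim`: exp(mean(log|x|)).
    # The eps clamp is an assumption of this sketch; it guards against log(0).
    log_abs = x.abs().clamp_min(eps).log()
    geometric_mean = log_abs.mean(dim=dim, keepdim=True).exp()
    return x / geometric_mean


x = torch.tensor([[1.0, 4.0], [2.0, 8.0]])
y = product_normalize_sketch(x)
print(y)
print(y.prod(dim=-1))
```

Both rows normalize to `[0.5, 2.0]`, so the product along the normalized dimension is 1, which is the sense in which a normalized box shape has a volume of 1.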