BoxEInteraction
- class BoxEInteraction(tanh_map: bool = True, p: int = 2, power_norm: bool = False)
Bases: NormBasedInteraction[tuple[Tensor, Tensor], tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor], tuple[Tensor, Tensor]]
The BoxE interaction from [abboud2020].
Entities are represented by two \(d\)-dimensional vectors: a base position and a translational bump. The bump translates every entity that co-occurs in a fact with this entity from its base position to its final embedding; this translation is called “bumping”.
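As a minimal sketch of bumping, the snippet below uses plain Python lists of floats in place of tensors. The helper name `bump` and the 2-d values are hypothetical. For a fact \(r(h, t)\), each entity's final embedding is its own base position plus the bump of the other entity in the fact:

```python
def bump(base: list[float], other_bump: list[float]) -> list[float]:
    """Translate an entity's base position by the bump of its co-occurring entity."""
    return [b + o for b, o in zip(base, other_bump)]


# Hypothetical 2-d base positions and bumps for a fact r(h, t).
h_base, h_bump = [0.0, 1.0], [0.5, -0.5]
t_base, t_bump = [2.0, 0.0], [-1.0, 0.25]

h_final = bump(h_base, t_bump)  # the head is moved by the tail's bump
t_final = bump(t_base, h_bump)  # the tail is moved by the head's bump
```

Because each bump affects only the other entity, the same entity can end up in different final positions depending on which entity it appears with.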
Relations are represented by a fixed number of hyper-rectangles (boxes), one per position of the relation’s arity. Since only single-hop link prediction is considered here, the arity is always two, i.e., there is one box for the head position and another for the tail position. A hyper-rectangle can be parametrized in several ways; the most common is by the coordinates of two opposing vertices. BoxE instead parametrizes each box by
a base position given by its center, a \(d\)-dimensional vector \(\mathbf{c} \in \mathbb{R}^d\)
an extent in each dimension. This size is further factored into
a global scalar scaling factor, \(s \in \mathbb{R}\)
a normalized extent in each dimension, i.e., the extents sum to one, given as \(\mathbf{e} \in \mathbb{R}^d\), with \(\|\mathbf{e}\|_1 = 1\) and \(0 \leq \mathbf{e}_i\) for all \(i\).
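The center/scale/extent parametrization can be converted back to the usual pair of opposing corners. The sketch below works on plain Python lists and uses the hypothetical helpers `box_corners` and `contains`. It assumes the extent is already normalized as described, so the side length in dimension \(i\) is \(s \cdot \mathbf{e}_i\), and the corners are the center minus and plus half of that length:

```python
def box_corners(center, scale, extent):
    """Return the (low, high) corners of a box from its center, scale and extent.

    Assumes `extent` is already normalized (non-negative entries summing to one);
    the side length in dimension i is then scale * extent[i].
    """
    half = [0.5 * scale * e for e in extent]
    low = [c - hw for c, hw in zip(center, half)]
    high = [c + hw for c, hw in zip(center, half)]
    return low, high


def contains(low, high, point):
    """Check whether a point lies inside the box in every dimension."""
    return all(lo <= p <= hi for lo, p, hi in zip(low, point, high))


# Hypothetical 2-d box: the extent [0.25, 0.75] sums to one, scaled by 4.
low, high = box_corners(center=[0.0, 1.0], scale=4.0, extent=[0.25, 0.75])
```

Separating the scale from the normalized extent lets the overall box size and the box shape be learned independently.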
Instantiate the interaction module.
See also: The parameters p and power_norm are passed directly to NormBasedInteraction.
- Parameters:
tanh_map (bool) – Whether to use tanh mapping after BoxE computation (defaults to true). The hyperbolic tangent mapping restricts the embedding space to the range [-1, 1], and thus this map implicitly regularizes the space to prevent loss reduction by growing boxes arbitrarily large.
p (int) – The norm used with torch.linalg.vector_norm(). Typically 1 or 2.
power_norm (bool) – Whether to use the p-th power of the \(L_p\) norm. This has the advantage of being differentiable around 0 and more numerically stable.
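The sketch below shows how `p` and `power_norm` interact, using plain Python floats. It assumes the score is the negated norm, so that larger is better. This is a common convention for norm-based interactions but an assumption here. The function name `negative_norm` is hypothetical and is not NormBasedInteraction's actual code:

```python
def negative_norm(x, p=2, power_norm=False):
    """Negated L_p norm of x; with power_norm, the negated p-th power of the norm.

    For p = 2, power_norm skips the square root, whose gradient is undefined at zero.
    """
    powered = sum(abs(v) ** p for v in x)  # sum_i |x_i|^p
    return -powered if power_norm else -(powered ** (1.0 / p))


score = negative_norm([3.0, 4.0])                         # Euclidean length 5, negated
score_pow = negative_norm([3.0, 4.0], power_norm=True)    # squared length 25, negated
```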
Attributes Summary
entity_shape – The symbolic shapes for entity representations
relation_shape – The symbolic shapes for relation representations
Methods Summary
boxe_kg_arity_position_score(entity_pos, ...) – Perform the BoxE computation at a single arity position.
compute_box(base, delta, size) – Compute the lower and upper corners of a resulting box.
forward(h, r, t) – Evaluate the interaction function.
point_to_box_distance(points, box_lows, ...) – Compute the point-to-box distance function proposed by [abboud2020] element-wise.
product_normalize(x[, dim]) – Normalize a tensor along a given dimension so that its geometric mean is 1.0.
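The sketch below illustrates the property stated for product_normalize on a single plain Python list, ignoring the `dim` argument. It divides every entry by the geometric mean of the absolute values, so the normalized entries have geometric mean 1.0 and their absolute product is 1. The `eps` floor that avoids `log(0)` is an added assumption, not a documented parameter:

```python
import math


def product_normalize(x, eps=1e-12):
    """Scale x so that the geometric mean of its absolute values is 1.0.

    `eps` (hypothetical) guards against log(0); the signs of the entries are kept.
    """
    log_mean = sum(math.log(max(abs(v), eps)) for v in x) / len(x)
    geometric_mean = math.exp(log_mean)
    return [v / geometric_mean for v in x]


scaled = product_normalize([2.0, 8.0])  # geometric mean is 4, so approximately [0.5, 2.0]
```

Computing the mean in log space avoids overflow or underflow from multiplying many entries directly.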
Attributes Documentation
- relation_shape: Sequence[str] = ('d', 'd', 's', 'd', 'd', 's')
The symbolic shapes for relation representations
Methods Documentation
- classmethod boxe_kg_arity_position_score(entity_pos: Tensor, other_entity_bump: Tensor, relation_box: tuple[Tensor, Tensor], tanh_map: bool, p: int, power_norm: bool) → Tensor
Perform the BoxE computation at a single arity position.
Note: This computation is parallelizable across all positions.
Note: entity_pos, other_entity_bump, and both corners of relation_box (relation_box_low and relation_box_high) must have broadcastable shapes.
- Parameters:
entity_pos (Tensor) – shape: (*s_p, d). The base position of the entity appearing in the target position. For example, for a fact \(r(h, t)\) and the head arity position, entity_pos is the base position of \(h\).
other_entity_bump (Tensor) – shape: (*s_b, d). The bump of the entity at the other position in the fact. For example, given a fact \(r(h, t)\) and the head arity position, other_entity_bump is the bump of \(t\).
relation_box (tuple[Tensor, Tensor]) – shape: (*s_r, d). The lower and upper corners of the relation box at the target arity position.
tanh_map (bool) – Whether to apply the tanh map regularizer.
p (int) – The norm used with torch.linalg.vector_norm(). Typically 1 or 2.
power_norm (bool) – Whether to use the p-th power of the \(L_p\) norm. This has the advantage of being differentiable around 0 and more numerically stable.
- Returns:
shape: *s. Arity-position score for the entity relative to the target relation box. Larger is better. The shape is the broadcast shape of position, bump, and box, with the last dimension removed.
- Return type:
Tensor
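The sketch below follows one plausible reading of the parameters above, on plain Python lists: bump the base position, optionally apply tanh to the point and both box corners, take the element-wise point-to-box distance, and return a negated norm so that larger is better. The order of these steps and the negation are assumptions. The name `arity_position_score` is hypothetical and is not the library's implementation:

```python
import math


def point_to_box_distance(p, low, high):
    """Element-wise BoxE distance for one coordinate (piecewise definition)."""
    w = high - low
    c = (high + low) / 2
    if low <= p <= high:
        return abs(p - c) / (w + 1)
    return abs(p - c) * (w + 1) - 0.5 * w * ((w + 1) - 1 / (w + 1))


def arity_position_score(entity_pos, other_entity_bump, box_low, box_high,
                         tanh_map=True, p=2, power_norm=False):
    """Score of one entity against one relation box; larger is better."""
    # Step 1: bump the base position by the other entity's bump.
    point = [e + b for e, b in zip(entity_pos, other_entity_bump)]
    # Step 2: optionally squash the point and box corners into (-1, 1).
    if tanh_map:
        point = [math.tanh(v) for v in point]
        box_low = [math.tanh(v) for v in box_low]
        box_high = [math.tanh(v) for v in box_high]
    # Step 3: element-wise distance of the bumped point to the box.
    dist = [point_to_box_distance(x, lo, hi)
            for x, lo, hi in zip(point, box_low, box_high)]
    # Step 4: negated L_p norm (or its p-th power), so closer points score higher.
    powered = sum(abs(d) ** p for d in dist)
    return -powered if power_norm else -(powered ** (1.0 / p))
```

A point at the box center scores 0, the maximum, and points farther outside the box score progressively lower.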
- classmethod compute_box(base: Tensor, delta: Tensor, size: Tensor) → tuple[Tensor, Tensor]
Compute the lower and upper corners of a resulting box.
- Parameters:
base (Tensor), delta (Tensor), size (Tensor) – the box embeddings.
- Returns:
shape: (*, d). The lower and upper bounds of the box whose embeddings are provided as input.
- Return type:
tuple[Tensor, Tensor]
- forward(h: tuple[Tensor, Tensor], r: tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor], t: tuple[Tensor, Tensor]) → Tensor
Evaluate the interaction function.
See also: Interaction.forward for a detailed description of the generic batched form of the interaction function.
- Parameters:
h (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The head representations.
r (tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]) – shape: (*batch_dims, d), (*batch_dims, d), (*batch_dims, s), (*batch_dims, d), (*batch_dims, d), and (*batch_dims, s). The relation representations.
t (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The tail representations.
- Returns:
shape: batch_dims. The scores.
- Return type:
Tensor
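forward combines the two arity positions. The sketch below shows how the inputs could pair up. It assumes the order of `h` and `t` is (base position, bump), and that `r` is ordered as (base, delta, size) for the head box followed by the same three for the tail box, which matches relation_shape ('d', 'd', 's', 'd', 'd', 's'). `pair_inputs` and `forward_sketch` are hypothetical helpers. Summing the two position scores is also an assumption, since the text above does not state how they are combined:

```python
def pair_inputs(h, r, t):
    """Split BoxE inputs into (entity position, other entity's bump, box params) per position."""
    h_pos, h_bump = h
    t_pos, t_bump = t
    rh_base, rh_delta, rh_size, rt_base, rt_delta, rt_size = r
    # Head position: the head's base is bumped by the tail, scored against the head box.
    head = (h_pos, t_bump, (rh_base, rh_delta, rh_size))
    # Tail position: the tail's base is bumped by the head, scored against the tail box.
    tail = (t_pos, h_bump, (rt_base, rt_delta, rt_size))
    return head, tail


def forward_sketch(h, r, t, position_score):
    """Total score as the sum of both arity-position scores (an assumption)."""
    head, tail = pair_inputs(h, r, t)
    return position_score(*head) + position_score(*tail)
```

Using string placeholders makes the pairing easy to inspect. For example, `pair_inputs(("h_pos", "h_bump"), ("rh_b", "rh_d", "rh_s", "rt_b", "rt_d", "rt_s"), ("t_pos", "t_bump"))` shows that the head position uses the tail's bump and the head box.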
- static point_to_box_distance(points: Tensor, box_lows: Tensor, box_highs: Tensor) → Tensor
Compute the point to box distance function proposed by [abboud2020] in an element-wise fashion.
- Parameters:
points (Tensor), box_lows (Tensor), box_highs (Tensor)
- Returns:
Element-wise distance function scores, as defined below.
Given points \(p\), box_lows \(l\), and box_highs \(h\), the following quantities are defined:
Width \(w\) is the difference between the upper and lower box bounds: \(w = h - l\)
Box centers \(c\) are the mean of the box bounds: \(c = (h + l) / 2\)
Finally, the point-to-box distance \(\operatorname{dist}(p, l, h)\) is defined by the following piecewise function:
\[\operatorname{dist}(p, l, h) = \begin{cases} |p - c| / (w + 1) & \text{if } l \le p \le h \\ |p - c| \cdot (w + 1) - 0.5 \cdot w \cdot \left((w + 1) - 1/(w + 1)\right) & \text{otherwise} \end{cases}\]
- Return type:
Tensor
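The piecewise definition translates directly into an element-wise function. The sketch below uses equal-length plain Python lists instead of tensors. Inside the box, the distance to the center is shrunk by \(w + 1\). Outside, it is amplified by \(w + 1\), and the offset term makes both branches agree at the boundary \(p = c \pm w/2\), where each evaluates to \((w/2)/(w + 1)\):

```python
def point_to_box_distance(points, box_lows, box_highs):
    """Element-wise BoxE point-to-box distance on equal-length lists of floats."""
    out = []
    for p, lo, hi in zip(points, box_lows, box_highs):
        w = hi - lo          # box width
        c = (hi + lo) / 2    # box center
        if lo <= p <= hi:
            # Inside: distance to the center, shrunk by (w + 1).
            out.append(abs(p - c) / (w + 1))
        else:
            # Outside: distance to the center, amplified by (w + 1), minus an
            # offset that keeps the function continuous at the box boundary.
            out.append(abs(p - c) * (w + 1) - 0.5 * w * ((w + 1) - 1 / (w + 1)))
    return out


# A point at the center of [-1, 1] and a point outside it.
distances = point_to_box_distance([0.0, 2.0], [-1.0, -1.0], [1.0, 1.0])
```

With \(w = 2\) and \(c = 0\), the second point gives \(2 \cdot 3 - 1 \cdot (3 - 1/3) = 10/3\). By contrast, a point at the inner boundary only reaches \(1/3\), so points outside the box are penalized much more steeply.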