CompGCNLayer
- class CompGCNLayer(input_dim, output_dim=None, dropout=0.0, use_bias=True, use_relation_bias=False, composition=None, attention_heads=4, attention_dropout=0.1, activation=<class 'torch.nn.modules.linear.Identity'>, activation_kwargs=None, edge_weighting=<class 'pykeen.nn.weighting.SymmetricEdgeWeighting'>)
Bases: Module
A single layer of the CompGCN model.
Initialize the module.
- Parameters
  - input_dim (int) – The input dimension.
  - output_dim (Optional[int]) – The output dimension. If None, it equals the input dimension.
  - dropout (float) – The dropout to use for forward and backward edges.
  - use_bias (bool) – Whether to use a bias. Note that this bias comes right before a mandatory batch normalization layer, so it may be redundant.
  - use_relation_bias (bool) – Whether to use a bias for the relation transformation.
  - composition (Union[str, CompositionModule, None]) – The composition function.
  - attention_heads (int) – The number of attention heads when using the attention edge weighting.
  - attention_dropout (float) – The dropout for the attention message weighting.
  - activation (Union[str, Module, None]) – The activation to use.
  - activation_kwargs (Optional[Mapping[str, Any]]) – Additional keyword arguments passed to the activation.
  - edge_weighting (Union[str, Type[EdgeWeighting], None]) – A pre-instantiated EdgeWeighting, a class, or a name to look up with class_resolver.
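The composition argument selects the function \(\phi\) that combines an entity representation with a relation representation in the update formula of forward. To illustrate what such functions compute, the following pure-Python sketch implements the three compositions proposed in the CompGCN paper (subtraction, element-wise multiplication, and circular correlation) on plain lists. The names sub, mult, and corr are hypothetical and are not PyKEEN's CompositionModule API; which composition PyKEEN picks when composition=None is not stated here.

```python
from typing import List, Sequence

# Hypothetical helpers illustrating common CompGCN composition functions phi(x, r);
# these are NOT PyKEEN's CompositionModule classes.


def sub(x: Sequence[float], r: Sequence[float]) -> List[float]:
    """Subtraction composition: phi(x, r) = x - r."""
    return [xi - ri for xi, ri in zip(x, r)]


def mult(x: Sequence[float], r: Sequence[float]) -> List[float]:
    """Element-wise multiplication composition: phi(x, r) = x * r."""
    return [xi * ri for xi, ri in zip(x, r)]


def corr(x: Sequence[float], r: Sequence[float]) -> List[float]:
    """Circular correlation: phi(x, r)[k] = sum_i x[i] * r[(i + k) mod d]."""
    d = len(x)
    return [sum(x[i] * r[(i + k) % d] for i in range(d)) for k in range(d)]


if __name__ == "__main__":
    x, r = [1.0, 2.0], [3.0, 4.0]
    print(sub(x, r), mult(x, r), corr(x, r))
```

Subtraction and multiplication act on each dimension independently, whereas circular correlation mixes all dimensions. In this naive form it costs O(d²) per pair; an FFT-based formulation is the usual way to bring it down to O(d log d).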
Methods Summary
forward(x_e, x_r, edge_index, edge_type) – Update entity and relation representations.
message(x_e, x_r, edge_index, edge_type, weight) – Perform message passing.
Reset the model's parameters.
Methods Documentation
- forward(x_e, x_r, edge_index, edge_type)
Update entity and relation representations.
\[X_E'[e] = \frac{1}{3} \left( X_E[e] W_s + \left( \sum_{(h,r,e) \in T} \alpha(h, e) \phi(X_E[h], X_R[r]) W_f \right) + \left( \sum_{(e,r,t) \in T} \alpha(e, t) \phi(X_E[t], X_R[r^{-1}]) W_b \right) \right)\]
where \(T\) is the set of triples given by edge_index and edge_type, \(\phi\) is the composition function, \(\alpha\) is the edge weighting, and \(W_s\), \(W_f\), \(W_b\) are the self-loop, forward, and backward weight matrices.
- Parameters
  - x_e (FloatTensor) – shape: (num_entities, input_dim). The entity representations.
  - x_r (FloatTensor) – shape: (2 * num_relations, input_dim). The relation representations (including inverse relations).
  - edge_index (LongTensor) – shape: (2, num_edges). The edge index, i.e., the pair of source and target entities for each triple.
  - edge_type (LongTensor) – shape: (num_edges,). The edge type, i.e., the relation ID, for each triple.
- Return type
  Tuple[FloatTensor, FloatTensor]
- Returns
  shape: (num_entities, output_dim) / (2 * num_relations, output_dim). The updated entity and relation representations.
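To make the update formula above concrete, here is a minimal pure-Python sketch of the entity and relation update on nested lists. It is not PyKEEN's implementation and rests on several assumptions: subtraction is used as \(\phi\); \(\alpha\) is a per-edge weight list that defaults to 1; relation r and its inverse are stored at rows 2r and 2r + 1 of x_r; and relations are updated as X_R' = X_R W_rel, following the CompGCN paper. Dropout, biases, batch normalization, and the activation are omitted. All helper names (matmul, phi, compgcn_forward, w_s, w_f, w_b, w_rel) are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of an (n x k) matrix and a (k x m) matrix."""
    return [[sum(row[k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for row in a]


def phi(x: Sequence[float], r: Sequence[float]) -> List[float]:
    """Composition function; subtraction x - r is assumed for illustration."""
    return [xi - ri for xi, ri in zip(x, r)]


def compgcn_forward(
    x_e: Matrix,
    x_r: Matrix,
    edge_index: Tuple[List[int], List[int]],
    edge_type: List[int],
    w_s: Matrix,
    w_f: Matrix,
    w_b: Matrix,
    w_rel: Matrix,
    alpha: Optional[List[float]] = None,
) -> Tuple[Matrix, Matrix]:
    """Hypothetical sketch of X_E' and X_R' (no dropout, bias, batch norm, or activation)."""
    dim = len(x_e[0])
    if alpha is None:
        alpha = [1.0] * len(edge_type)  # assumed uniform edge weighting
    fwd = [[0.0] * dim for _ in x_e]  # weighted forward messages summed per target entity
    bwd = [[0.0] * dim for _ in x_e]  # weighted backward messages summed per source entity
    for (h, t), r, a in zip(zip(*edge_index), edge_type, alpha):
        # forward direction h -> t with relation r, assumed to be row 2r of x_r
        for i, v in enumerate(phi(x_e[h], x_r[2 * r])):
            fwd[t][i] += a * v
        # backward direction t -> h with the inverse relation, assumed to be row 2r + 1
        for i, v in enumerate(phi(x_e[t], x_r[2 * r + 1])):
            bwd[h][i] += a * v
    # W_f and W_b are linear, so summing before the matrix product gives the same result
    self_term, fwd_term, bwd_term = matmul(x_e, w_s), matmul(fwd, w_f), matmul(bwd, w_b)
    new_x_e = [
        [(s + f + b) / 3 for s, f, b in zip(s_row, f_row, b_row)]
        for s_row, f_row, b_row in zip(self_term, fwd_term, bwd_term)
    ]
    new_x_r = matmul(x_r, w_rel)  # assumed relation update X_R' = X_R W_rel
    return new_x_e, new_x_r


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    # one triple (0, 0, 1); x_r holds relation 0 at row 0 and its inverse at row 1
    new_e, new_r = compgcn_forward(
        x_e=[[1.0, 2.0], [3.0, 4.0]],
        x_r=[[0.5, 0.5], [1.0, 1.0]],
        edge_index=([0], [1]),
        edge_type=[0],
        w_s=identity, w_f=identity, w_b=identity, w_rel=identity,
    )
    print(new_e, new_r)
```

Because W_f and W_b are linear, the sketch first sums the weighted composed messages per entity and then applies the weight matrix once, which is equivalent to transforming every message individually and cheaper when entities have many neighbors.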
- message(x_e, x_r, edge_index, edge_type, weight)
Perform message passing.
- Parameters
  - x_e (FloatTensor) – shape: (num_entities, input_dim). The entity representations.
  - x_r (FloatTensor) – shape: (2 * num_relations, input_dim). The relation representations (including inverse relations).
  - edge_index (LongTensor) – shape: (2, num_edges). The edge index, i.e., the pair of source and target entities for each triple.
  - edge_type (LongTensor) – shape: (num_edges,). The edge type, i.e., the relation ID, for each triple.
  - weight (Parameter) – The transformation weight.
- Return type
  FloatTensor
- Returns
  The updated entity representations.
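During message passing, each composed message is scaled by the edge weight \(\alpha\), which the edge_weighting argument controls; its default is SymmetricEdgeWeighting. The sketch below assumes a GCN-style symmetric normalization, \(\alpha(s, t) = 1 / \sqrt{\deg_{out}(s) \deg_{in}(t)}\), where the degrees count how often an entity appears as a source or as a target in edge_index. This definition is an illustrative assumption, so check pykeen.nn.weighting for the exact formula; the function name symmetric_edge_weights is hypothetical.

```python
from collections import Counter
from math import sqrt
from typing import List, Sequence


def symmetric_edge_weights(source: Sequence[int], target: Sequence[int]) -> List[float]:
    """Hypothetical sketch: alpha(s, t) = 1 / sqrt(out_deg(s) * in_deg(t)) for every edge s -> t."""
    out_degree = Counter(source)  # number of edges leaving each entity
    in_degree = Counter(target)  # number of edges entering each entity
    return [1.0 / sqrt(out_degree[s] * in_degree[t]) for s, t in zip(source, target)]


if __name__ == "__main__":
    # the two rows of an edge_index as plain lists: edges 0 -> 1, 0 -> 2, 1 -> 2
    print(symmetric_edge_weights([0, 0, 1], [1, 2, 2]))
```

Flipping every edge, as the backward term of the update does, leaves each edge's weight unchanged under this definition, because the out-degree and in-degree roles swap together.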