"""Perceptron-like modules."""
from torch import nn
from ..typing import FloatTensor
__all__ = [
    "TwoLayerMLP",
    "ConcatMLP",
]
class TwoLayerMLP(nn.Sequential):
    """A 2-layer MLP with ReLU activation and dropout."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int | None = None,
        dropout: float = 0.1,
        ratio: int | float = 2,
    ) -> None:
        """Initialize the module.

        :param input_dim: the input dimension
        :param output_dim: the output dimension; defaults to the input dimension
        :param dropout: the dropout value applied to the hidden layer
        :param ratio: the ratio of the hidden layer size to the output dimension,
            i.e., ``hidden_dim = int(ratio * output_dim)``
        """
        output_dim = output_dim or input_dim
        # the hidden layer is `ratio` times as large as the output layer
        hidden_dim = int(ratio * output_dim)
        super().__init__(
            nn.Linear(input_dim, hidden_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
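# A minimal usage sketch (illustrative, not part of the original module): with the
# default ``ratio=2`` and ``output_dim=32``, the hidden layer has ``int(2 * 32) = 64``
# units, so a batch of 64-dimensional inputs is mapped 64 -> 64 -> 32.
#
#   >>> import torch
#   >>> mlp = TwoLayerMLP(input_dim=64, output_dim=32)
#   >>> mlp(torch.rand(128, 64)).shape
#   torch.Size([128, 32])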
class ConcatMLP(TwoLayerMLP):
    """A 2-layer MLP with ReLU activation and dropout applied to the flattened token representations.

    This is a convenience class for choosing a configuration similar to the one used in the NodePiece paper.
    For more complex aggregation mechanisms, pass an arbitrary callable instead.

    .. seealso::
        https://github.com/migalkin/NodePiece/blob/d731c9990/lp_rp/pykeen105/nodepiece_rotate.py#L57-L65
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int | None = None,
        dropout: float = 0.1,
        ratio: int | float = 2,
        flatten_dims: int = 2,
    ) -> None:
        """Initialize the module.

        :param input_dim: the input dimension
        :param output_dim: the output dimension; defaults to the input dimension
        :param dropout: the dropout value applied to the hidden layer
        :param ratio: the ratio of the hidden layer size to the output dimension
        :param flatten_dims: the number of trailing dimensions to flatten
        """
        super().__init__(input_dim=input_dim, output_dim=output_dim, dropout=dropout, ratio=ratio)
        self.flatten_dims = flatten_dims
    def forward(self, xs: FloatTensor, dim: int) -> FloatTensor:
        """Forward the MLP on the given dimension.

        :param xs: The tensor to forward
        :param dim: Only present to match the signature of :func:`torch.mean` / :func:`torch.sum`;
            this class is not meant to be used from outside. Only ``dim == -2`` is supported.

        :returns: The tensor after applying this MLP
        """
        assert dim == -2
        # flatten the trailing `flatten_dims` dimensions before applying the two-layer MLP
        return super().forward(xs.view(*xs.shape[: -self.flatten_dims], -1))
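# A minimal usage sketch (illustrative, not part of the original module): ConcatMLP is
# intended as an aggregation callable over token representations of shape
# ``(..., num_tokens, embedding_dim)``. With the default ``flatten_dims=2`` the last two
# dimensions are concatenated, so ``input_dim`` must equal ``num_tokens * embedding_dim``.
#
#   >>> import torch
#   >>> mlp = ConcatMLP(input_dim=20 * 32, output_dim=32)
#   >>> xs = torch.rand(8, 20, 32)  # (batch, num_tokens, embedding_dim)
#   >>> mlp(xs, dim=-2).shape
#   torch.Size([8, 32])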