Source code for pykeen.nn.perceptron

# -*- coding: utf-8 -*-

"""Perceptron-like modules."""

from typing import Union

import torch
from torch import nn

__all__ = [
    "ConcatMLP",
]


[docs]class ConcatMLP(nn.Sequential): """A 2-layer MLP with ReLU activation and dropout applied to the concatenation of token representations. This is for conveniently choosing a configuration similar to the paper. For more complex aggregation mechanisms, pass an arbitrary callable instead. .. seealso:: https://github.com/migalkin/NodePiece/blob/d731c9990/lp_rp/pykeen105/nodepiece_rotate.py#L57-L65 """ def __init__( self, num_tokens: int, embedding_dim: int, dropout: float = 0.1, ratio: Union[int, float] = 2, ): """Initialize the module. :param num_tokens: the number of tokens :param embedding_dim: the embedding dimension for a single token :param dropout: the dropout value on the hidden layer :param ratio: the ratio of the embedding dimension to the hidden layer size. """ hidden_dim = int(ratio * embedding_dim) super().__init__( nn.Linear(num_tokens * embedding_dim, hidden_dim), nn.Dropout(dropout), nn.ReLU(), nn.Linear(hidden_dim, embedding_dim), )
[docs] def forward(self, xs: torch.FloatTensor, dim: int) -> torch.FloatTensor: """Forward the MLP on the given dimension. :param xs: The tensor to forward :param dim: Only a parameter to match the signature of torch.mean / torch.sum this class is not thought to be usable from outside :returns: The tensor after applying this MLP """ assert dim == -2 return super().forward(xs.view(*xs.shape[:-2], -1))