# Source code for pykeen.losses

# -*- coding: utf-8 -*-

"""Loss functions implemented in PyKEEN and additionally imported from :class:`torch`.

===============  ==========================================
Name             Reference
===============  ==========================================
bce              :class:`torch.nn.BCELoss`
bceaftersigmoid  :class:`pykeen.losses.BCEAfterSigmoidLoss`
crossentropy     :class:`pykeen.losses.CrossEntropyLoss`
marginranking    :class:`torch.nn.MarginRankingLoss`
mse              :class:`torch.nn.MSELoss`
nssa             :class:`pykeen.losses.NSSALoss`
softplus         :class:`pykeen.losses.SoftplusLoss`
===============  ==========================================

.. note:: This table can be re-generated with ``pykeen ls losses -f rst``
"""

from typing import Any, Mapping, Set, Type, Union

import torch
from torch import nn
from torch.nn import BCELoss, MSELoss, MarginRankingLoss, functional

from .utils import get_cls, normalize_string

__all__ = [
    'Loss',
    'BCEAfterSigmoidLoss',
    'SoftplusLoss',
    'NSSALoss',
    'CrossEntropyLoss',
    'MarginRankingLoss',
    'MSELoss',
    'BCELoss',
    'losses',
    'losses_hpo_defaults',
    'get_loss_cls',
]

# Reduction name -> reduction function. The custom losses in this module look their
# ``reduction`` argument up here at construction time, so an unsupported name (e.g. 'none')
# raises KeyError there. NSSALoss relies on 'mean' mapping to exactly ``torch.mean``.
_REDUCTION_METHODS = dict(
    mean=torch.mean,
    sum=torch.sum,
)

# Base class of the (non-weighted) loss modules implemented in torch.nn.
Loss = nn.modules.loss._Loss


class SoftplusLoss(nn.Module):
    """A loss function for the softplus.

    Labels in ``[0, 1]`` are rescaled to ``[-1, 1]`` and the loss is
    ``softplus(-(2 * labels - 1) * logits)``, reduced with the configured reduction.
    """

    def __init__(self, reduction: str = 'mean') -> None:
        """Initialize the loss.

        :param reduction: name of the reduction; looked up in the module-level
            ``_REDUCTION_METHODS`` table, so an unknown name raises :class:`KeyError`.
        """
        super().__init__()
        self.reduction = reduction
        self.softplus = torch.nn.Softplus(beta=1, threshold=20)
        self._reduction_method = _REDUCTION_METHODS[reduction]

    def forward(
        self,
        logits: torch.FloatTensor,
        labels: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Calculate the loss for the given scores and labels.

        :param logits: the predicted scores.
        :param labels: the target labels; every value must lie in ``[0, 1]``.
        :return: the reduced loss.
        :raises ValueError: if a label lies outside ``[0, 1]`` (or is NaN).
        """
        # Explicit check instead of ``assert``: assertions are stripped under ``python -O``.
        # The condition is written so that NaN labels are rejected as well.
        if not (0. <= labels.min() and labels.max() <= 1.):
            raise ValueError(
                f'labels must lie in [0, 1], but got values in '
                f'[{labels.min().item()}, {labels.max().item()}]',
            )
        # scale labels from [0, 1] to [-1, 1]
        labels = 2 * labels - 1
        loss = self.softplus((-1) * labels * logits)
        return self._reduction_method(loss)
class BCEAfterSigmoidLoss(nn.Module):
    """A loss function which uses the numerically unstable version of explicit Sigmoid + BCE."""

    def __init__(self, reduction: str = 'mean') -> None:
        """Initialize the loss.

        :param reduction: the reduction forwarded to
            :func:`torch.nn.functional.binary_cross_entropy`, unless a call overrides it.
        """
        super().__init__()
        self.reduction = reduction

    def forward(
        self,
        logits: torch.FloatTensor,
        labels: torch.FloatTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        """Apply an explicit sigmoid to the logits, then binary cross entropy.

        :param logits: the predicted scores (before the sigmoid).
        :param labels: the target labels.
        :param kwargs: extra keyword arguments for
            :func:`torch.nn.functional.binary_cross_entropy`; an explicit ``reduction`` here
            takes precedence over the one given at construction time.
        :return: the loss, reduced according to the effective reduction.
        """
        post_sigmoid = torch.sigmoid(logits)
        # Honour the configured reduction; ``kwargs`` is a fresh dict, so the caller's mapping
        # is not mutated.
        kwargs.setdefault('reduction', self.reduction)
        return functional.binary_cross_entropy(post_sigmoid, labels, **kwargs)
class CrossEntropyLoss(nn.Module):
    """Evaluate cross entropy after softmax output.

    The labels are L1-normalised along the last dimension into a target distribution, and the
    per-sample cross entropy against ``softmax(logits)`` is reduced with the configured reduction.
    """

    def __init__(self, reduction: str = 'mean'):
        """Initialize the loss.

        :param reduction: name of the reduction; looked up in the module-level
            ``_REDUCTION_METHODS`` table (an unknown name raises :class:`KeyError`).
        """
        super().__init__()
        self.reduction = reduction
        self._reduction_method = _REDUCTION_METHODS[reduction]

    def forward(
        self,
        logits: torch.FloatTensor,
        labels: torch.FloatTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        """Compute the reduced cross entropy between normalised labels and softmax(logits).

        :param logits: the predicted scores; the softmax runs over the last dimension.
        :param labels: non-normalised target weights with the same shape as ``logits``.
        :param kwargs: accepted for interface compatibility and ignored.
        :return: the reduced loss.
        """
        # A proper probability distribution is required, hence the L1 normalisation.
        target = functional.normalize(labels, p=1, dim=-1)
        # log_softmax is the numerically stable way to get log(softmax(logits)).
        log_prediction = functional.log_softmax(logits, dim=-1)
        # ce(b) = -sum_i target(b, i) * log_prediction(b, i)
        per_sample = torch.sum(target * log_prediction, dim=-1).neg()
        return self._reduction_method(per_sample)
class NSSALoss(nn.Module):
    """An implementation of the self-adversarial negative sampling loss function proposed by [sun2019]_."""

    def __init__(self, margin: float, adversarial_temperature: float, reduction: str = 'mean') -> None:
        """Initialize the loss.

        :param margin: the margin compared against the (negated-score) distances.
        :param adversarial_temperature: factor applied to the negative scores before the softmax
            that produces the self-adversarial weights.
        :param reduction: name of the reduction; looked up in the module-level
            ``_REDUCTION_METHODS`` table (an unknown name raises :class:`KeyError`).
        """
        super().__init__()
        self.margin = margin
        self.adversarial_temperature = adversarial_temperature
        self.reduction = reduction
        self._reduction_method = _REDUCTION_METHODS[reduction]

    def forward(
        self,
        pos_scores: torch.FloatTensor,
        neg_scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Calculate the loss for the given scores.

        Scores are treated as negative distances ``d = -score``. The loss is
        ``-reduce(logsigmoid(margin - d_pos)) - reduce(w * logsigmoid(d_neg - margin))``, where
        ``w`` is a softmax over ``neg_scores * adversarial_temperature`` along the last dimension,
        detached from the autograd graph. With the mean reduction the result is halved.
        See [sun2019]_ for the original formulation.

        :param pos_scores: scores of the positive samples.
        :param neg_scores: scores of the negative samples; the weighting softmax runs over the
            last dimension.
        :return: the loss.
        """
        pos_dist = -pos_scores
        neg_dist = -neg_scores

        # Self-adversarial weights; detached so that no gradient flows through them.
        weights = functional.softmax(neg_scores * self.adversarial_temperature, dim=-1).detach()

        positive_term = self._reduction_method(functional.logsigmoid(self.margin - pos_dist))
        # NOTE(review): [sun2019] sums the weighted negative terms per positive before averaging,
        # whereas here the reduction runs over every element of the weighted tensor (with 'mean'
        # this divides by the number of negatives as well) — confirm this scaling is intended.
        negative_term = self._reduction_method(weights * functional.logsigmoid(neg_dist - self.margin))

        total = -positive_term - negative_term
        return total / 2. if self._reduction_method is torch.mean else total
_LOSS_SUFFIX = 'Loss'

# Loss classes that can be resolved by name through ``losses`` / ``get_loss_cls``.
_LOSSES: Set[Type[Loss]] = {
    MarginRankingLoss,
    BCELoss,
    SoftplusLoss,
    BCEAfterSigmoidLoss,
    CrossEntropyLoss,
    MSELoss,
    NSSALoss,
}
# To add *all* losses implemented in Torch, uncomment the snippet below. Note that
# ``WeightedLoss`` is not imported in this module; it must be bound first (presumably to
# ``torch.nn.modules.loss._WeightedLoss``) for the snippet to run:
# _LOSSES.update({
#     loss
#     for loss in Loss.__subclasses__() + WeightedLoss.__subclasses__()
#     if not loss.__name__.startswith('_')
# })

# Normalized class name (via ``normalize_string`` with the ``Loss`` suffix; the module
# docstring table lists the resulting keys, e.g. ``softplus``) -> loss class.
losses: Mapping[str, Type[Loss]] = {
    normalize_string(cls.__name__, suffix=_LOSS_SUFFIX): cls
    for cls in _LOSSES
}

# Losses that have explicit hyper-parameter optimization defaults.
_EXPLICIT_LOSS_HPO_DEFAULTS: Mapping[Type[Loss], Mapping[str, Any]] = {
    MarginRankingLoss: dict(
        margin=dict(type=int, low=0, high=3, q=1),
    ),
}

# HPO defaults for every registered loss: the explicit entries first, then a fresh empty
# dictionary for each remaining loss. Built in a single expression rather than by item
# assignment into a name annotated with the read-only ``Mapping`` type, and without leaking
# a module-level loop variable.
losses_hpo_defaults: Mapping[Type[Loss], Mapping[str, Any]] = {
    **_EXPLICIT_LOSS_HPO_DEFAULTS,
    **{cls: {} for cls in _LOSSES if cls not in _EXPLICIT_LOSS_HPO_DEFAULTS},
}
def get_loss_cls(query: Union[None, str, Type[Loss]]) -> Type[Loss]:
    """Get the loss class.

    Resolution is delegated to ``get_cls`` from the package's ``utils`` module. It uses the
    ``losses`` lookup table, ``nn.Module`` as the required base class, the ``_LOSS_SUFFIX``
    suffix, and :class:`torch.nn.MarginRankingLoss` as the default.

    :param query: ``None``, the name of a loss, or a loss class. Presumably ``None`` selects
        the default — confirm against ``get_cls``.
    :return: the resolved loss class.
    """
    options = dict(
        base=nn.Module,
        lookup_dict=losses,
        default=MarginRankingLoss,
        suffix=_LOSS_SUFFIX,
    )
    return get_cls(query, **options)