Loss Functions

Loss functions integrated in PyKEEN.

Rather than re-using the built-in loss functions of PyTorch, we have elected to re-implement some of the code from torch.nn.modules.loss in order to encode the three different kinds of loss functions accepted by PyKEEN in a class hierarchy. This allows PyKEEN to handle different kinds of loss functions more dynamically and to share code between them. Further, it gives potential users more insight into how each loss works.

Throughout the following explanations of pointwise loss functions, pairwise loss functions, and setwise loss functions, we will assume the set of entities \(\mathcal{E}\), set of relations \(\mathcal{R}\), set of possible triples \(\mathcal{T} = \mathcal{E} \times \mathcal{R} \times \mathcal{E}\), set of possible subsets of possible triples \(2^{\mathcal{T}}\) (i.e., the power set of \(\mathcal{T}\)), set of positive triples \(\mathcal{K}\), set of negative triples \(\mathcal{\bar{K}}\), scoring function (e.g., TransE) \(f: \mathcal{T} \rightarrow \mathbb{R}\) and labeling function \(l:\mathcal{T} \rightarrow \{0,1\}\) where a value of 1 denotes the triple is positive (i.e., \((h,r,t) \in \mathcal{K}\)) and a value of 0 denotes the triple is negative (i.e., \((h,r,t) \notin \mathcal{K}\)).

Note

In most realistic use cases of knowledge graph embedding models, you will have observed a subset of positive triples \(\mathcal{T_{obs}} \subset \mathcal{K}\) and no observations over negative triples. Depending on the training assumption (sLCWA or LCWA), this will mean negative triples are generated in a variety of patterns.

Note

Following the open world assumption (OWA), triples \(\mathcal{\bar{K}}\) are better named “not positive” rather than negative. This is most relevant for pointwise loss functions. For pairwise and setwise loss functions, triples are compared as being more/less positive and the binary classification is not relevant.

Pointwise Loss Functions

A pointwise loss is applied to a single triple. It takes the form of \(L: \mathcal{T} \rightarrow \mathbb{R}\) and computes a real-value for the triple given its labeling. Typically, a pointwise loss function takes the form of \(g: \mathbb{R} \times \{0,1\} \rightarrow \mathbb{R}\) based on the scoring function and labeling function.

\[L(k) = g(f(k), l(k))\]

Examples

Pointwise Loss                | Formulation
------------------------------|-------------------------------------------------------------------------------
Square Error                  | \(g(s, l) = \frac{1}{2}(s - l)^2\)
Binary Cross Entropy          | \(g(s, l) = -(l \cdot \log(\sigma(s)) + (1 - l) \cdot \log(1 - \sigma(s)))\)
Pointwise Hinge               | \(g(s, l) = \max(0, \lambda - \hat{l} \cdot s)\)
Soft Pointwise Hinge          | \(g(s, l) = \log(1 + \exp(\lambda - \hat{l} \cdot s))\)
Pointwise Logistic (softplus) | \(g(s, l) = \log(1 + \exp(-\hat{l} \cdot s))\)

For the pointwise hinge, soft pointwise hinge, and pointwise logistic losses, \(\hat{l}\) denotes the label rescaled from \(\{0,1\}\) to \(\{-1,1\}\), i.e., \(\hat{l} = 2l - 1\). The logistic sigmoid function is defined as \(\sigma(z) = \frac{1}{1 + e^{-z}}\).
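As a concrete illustration, the formulations above can be written directly with PyTorch tensor operations. This is a sketch for clarity, not PyKEEN's actual implementation; the scores and labels are arbitrary example values:

```python
import torch


def square_error(s, l):
    # g(s, l) = 1/2 (s - l)^2
    return 0.5 * (s - l) ** 2


def binary_cross_entropy(s, l):
    # g(s, l) = -(l * log(sigmoid(s)) + (1 - l) * log(1 - sigmoid(s)))
    return -(l * torch.log(torch.sigmoid(s)) + (1 - l) * torch.log(1 - torch.sigmoid(s)))


def pointwise_hinge(s, l, margin=1.0):
    # g(s, l) = max(0, margin - l_hat * s), with labels rescaled to {-1, 1}
    l_hat = 2 * l - 1
    return torch.relu(margin - l_hat * s)


def soft_pointwise_hinge(s, l, margin=1.0):
    # g(s, l) = log(1 + exp(margin - l_hat * s)) = softplus(margin - l_hat * s)
    l_hat = 2 * l - 1
    return torch.nn.functional.softplus(margin - l_hat * s)


def pointwise_logistic(s, l):
    # special case of the soft pointwise hinge with margin = 0
    return soft_pointwise_hinge(s, l, margin=0.0)


scores = torch.tensor([2.5, -1.0])  # arbitrary scores f(k)
labels = torch.tensor([1.0, 0.0])   # labels l(k) in {0, 1}
```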

Note

The pointwise logistic loss can be considered as a special case of the pointwise soft hinge loss where \(\lambda = 0\).

Batching

The pointwise loss of a set of triples (i.e., a batch) \(\mathcal{L}_L: 2^{\mathcal{T}} \rightarrow \mathbb{R}\) is defined as the arithmetic mean of the pointwise losses over each triple in the subset \(\mathcal{B} \in 2^{\mathcal{T}}\):

\[\mathcal{L}_L(\mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum \limits_{k \in \mathcal{B}} L(k)\]
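In code, this batching amounts to a mean reduction over the per-triple losses; for instance, with PyTorch's built-in binary cross entropy with logits (the scores and labels are arbitrary placeholders):

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([1.2, -0.3, 0.8])  # f(k) for each triple in the batch
labels = torch.tensor([1.0, 0.0, 1.0])   # l(k) for each triple

per_triple = F.binary_cross_entropy_with_logits(scores, labels, reduction="none")
batch_loss = per_triple.mean()  # arithmetic mean over the batch
```

Passing reduction="mean" to the loss function directly computes the same batch-level value.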

Pairwise Loss Functions

A pairwise loss is applied to a pair of triples - a positive and a negative one. It is defined as \(L: \mathcal{K} \times \mathcal{\bar{K}} \rightarrow \mathbb{R}\) and computes a real value for the pair.

All pairwise loss functions implemented in PyKEEN induce an auxiliary loss function based on the chosen interaction function, \(L^{*}: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}\), that simply passes the scores through. Note that \(L\) is often used interchangeably with \(L^{*}\).

\[L(k, \bar{k}) = L^{*}(f(k), f(\bar{k}))\]

Delta Pairwise Loss Functions

Delta pairwise losses are computed on the difference between the scores of the negative and positive triples, \(\Delta := f(\bar{k}) - f(k)\), with a transfer function \(g: \mathbb{R} \rightarrow \mathbb{R}\), and take the form:

\[L^{*}(f(k), f(\bar{k})) = g(f(\bar{k}) - f(k)) := g(\Delta)\]

The following table shows delta pairwise loss functions:

Pairwise Loss                             | Activation | Margin             | Formulation
------------------------------------------|------------|--------------------|------------------------------------------------
Pairwise Hinge (margin ranking)           | ReLU       | \(\lambda \neq 0\) | \(g(\Delta) = \max(0, \Delta + \lambda)\)
Soft Pairwise Hinge (soft margin ranking) | softplus   | \(\lambda \neq 0\) | \(g(\Delta) = \log(1 + \exp(\Delta + \lambda))\)
Pairwise Logistic                         | softplus   | \(\lambda = 0\)    | \(g(\Delta) = \log(1 + \exp(\Delta))\)

Note

The pairwise logistic loss can be considered as a special case of the pairwise soft hinge loss where \(\lambda = 0\).
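A minimal sketch of the delta pairwise losses, written with PyTorch operations (the score values are arbitrary; higher scores are assumed to indicate more plausible triples):

```python
import torch
import torch.nn.functional as F

pos = torch.tensor([1.5, 0.2])  # f(k) for positive triples
neg = torch.tensor([0.3, 0.9])  # f(k_bar) for matched negative triples
margin = 1.0

delta = neg - pos                                 # Delta = f(k_bar) - f(k)
pairwise_hinge = F.relu(delta + margin)           # max(0, Delta + lambda)
soft_pairwise_hinge = F.softplus(delta + margin)  # log(1 + exp(Delta + lambda))
pairwise_logistic = F.softplus(delta)             # lambda = 0 special case
```

With the target fixed to 1, the pairwise hinge above coincides with PyTorch's built-in torch.nn.functional.margin_ranking_loss.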

Inseparable Pairwise Loss Functions

The following pairwise loss functions use the full generalized form \(L(k, \bar{k})\) for their definitions:

Pairwise Loss | Formulation
--------------|-------------------------------------------------------
Double Loss   | \(h(\bar{\lambda} + f(\bar{k})) + h(\lambda - f(k))\)
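Taking \(h\) to be the ReLU activation (one common choice; the margin values and scores below are arbitrary examples), the double loss can be sketched as:

```python
import torch
import torch.nn.functional as F


def double_loss(pos_score, neg_score, positive_margin=1.0, negative_margin=0.5):
    # h(negative_margin + f(k_bar)) + h(positive_margin - f(k)), with h = ReLU
    return F.relu(negative_margin + neg_score) + F.relu(positive_margin - pos_score)


# well-separated pair: positive scored above its margin, negative below
loss = double_loss(torch.tensor(2.0), torch.tensor(-1.0))
```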

Batching

The pairwise loss for a set of pairs of positive/negative triples \(\mathcal{L}_L: 2^{\mathcal{K} \times \mathcal{\bar{K}}} \rightarrow \mathbb{R}\) is defined as the arithmetic mean of the pairwise losses for each pair of positive and negative triples in the subset \(\mathcal{B} \in 2^{\mathcal{K} \times \mathcal{\bar{K}}}\).

\[\mathcal{L}_L(\mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum \limits_{(k, \bar{k}) \in \mathcal{B}} L(k, \bar{k})\]

Setwise Loss Functions

A setwise loss is applied to a set of triples, which can be either positive or negative. It is defined as \(L: 2^{\mathcal{T}} \rightarrow \mathbb{R}\). The two setwise loss functions implemented in PyKEEN, pykeen.losses.NSSALoss and pykeen.losses.CrossEntropyLoss, differ widely in their paradigms, but both share the notion that triples are not strictly positive or negative.

\[L(k_1, \dots, k_n) = g(f(k_1), \dots, f(k_n))\]
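For instance, a cross-entropy-style setwise loss treats the scores of one positive triple and its negatives as the logits of a softmax (a sketch for illustration, not PyKEEN's exact implementation; the score values are arbitrary):

```python
import torch
import torch.nn.functional as F

# scores for one set of candidate triples; index 0 is the true (positive) triple
scores = torch.tensor([[2.0, 0.5, -1.0, 0.1]])
true_index = torch.tensor([0])

# softmax cross entropy over the set: -log softmax(scores)[true_index]
loss = F.cross_entropy(scores, true_index)
```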

Batching

The setwise loss for a set of sets of triples \(\mathcal{L}_L: 2^{2^{\mathcal{T}}} \rightarrow \mathbb{R}\) is defined as the arithmetic mean of the setwise losses for each set of triples \(\mathcal{b}\) in the subset \(\mathcal{B} \in 2^{2^{\mathcal{T}}}\).

\[\mathcal{L}_L(\mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum \limits_{\mathcal{b} \in \mathcal{B}} L(\mathcal{b})\]

Classes

PointwiseLoss([reduction])

Pointwise loss functions compute an independent loss term for each triple-label pair.

DeltaPointwiseLoss([margin, ...])

A generic class for delta-pointwise losses.

MarginPairwiseLoss([margin, ...])

The generalized margin ranking loss.

PairwiseLoss([reduction])

Pairwise loss functions compare the scores of a positive triple and a negative triple.

SetwiseLoss([reduction])

Setwise loss functions compare the scores of several triples.

AdversarialLoss([...])

A loss with adversarial weighting of negative samples.

AdversarialBCEWithLogitsLoss([...])

An adversarially weighted BCE loss.

BCEAfterSigmoidLoss([reduction])

The numerically unstable version of explicit Sigmoid + BCE loss.

BCEWithLogitsLoss([reduction])

The binary cross entropy loss.

CrossEntropyLoss([reduction])

The cross entropy loss that evaluates the cross entropy after softmax output.

FocalLoss(*[, gamma, alpha])

The focal loss proposed by [lin2018].

InfoNCELoss([margin, ...])

The InfoNCE loss with additive margin proposed by [wang2022].

MarginRankingLoss([margin, reduction])

The pairwise hinge loss (i.e., margin ranking loss).

MSELoss([reduction])

The mean squared error loss.

NSSALoss([margin, adversarial_temperature, ...])

The self-adversarial negative sampling loss function proposed by [sun2019].

SoftplusLoss([reduction])

The pointwise logistic loss (i.e., softplus loss).

SoftPointwiseHingeLoss([margin, reduction])

The soft pointwise hinge loss.

PointwiseHingeLoss([margin, reduction])

The pointwise hinge loss.

DoubleMarginLoss(*[, positive_margin, ...])

A limit-based scoring loss, with separate margins for positive and negative elements from [sun2018].

SoftMarginRankingLoss([margin, reduction])

The soft pairwise hinge loss (i.e., soft margin ranking loss).

PairwiseLogisticLoss([reduction])

The pairwise logistic loss.

Class Inheritance Diagram

Module (torch.nn.Module)
└── _Loss
    └── Loss
        ├── PointwiseLoss
        │   ├── BCEAfterSigmoidLoss
        │   ├── BCEWithLogitsLoss
        │   ├── DeltaPointwiseLoss
        │   │   ├── PointwiseHingeLoss
        │   │   └── SoftPointwiseHingeLoss
        │   │       └── SoftplusLoss
        │   ├── DoubleMarginLoss
        │   ├── FocalLoss
        │   └── MSELoss
        ├── PairwiseLoss
        │   └── MarginPairwiseLoss
        │       ├── MarginRankingLoss
        │       └── SoftMarginRankingLoss
        │           └── PairwiseLogisticLoss
        └── SetwiseLoss
            ├── AdversarialLoss
            │   ├── AdversarialBCEWithLogitsLoss
            │   └── NSSALoss
            └── CrossEntropyLoss
                └── InfoNCELoss