FocalLoss
- class FocalLoss(*, gamma: float = 2.0, alpha: float | None = None, **kwargs)
Bases: PointwiseLoss
The focal loss proposed by [lin2018].
It is an adaptation of the (binary) cross-entropy loss that handles imbalanced data better. The implementation is strongly inspired by torchvision.ops.sigmoid_focal_loss(), except that it uses a module rather than the functional form. The loss is given as
\[FL(p_t) = -(1 - p_t)^\gamma \log(p_t)\]
with \(p_t = y \cdot p + (1 - y) \cdot (1 - p)\), where \(p\) is the predicted probability and \(y \in \{0, 1\}\) is the ground-truth label.
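As an illustration, the formula can be evaluated for a single example in plain Python. This is a minimal sketch, not PyKEEN's implementation: the function name `focal_loss` is hypothetical, it takes a probability \(p\) strictly between 0 and 1 rather than a raw score, and it ignores alpha.

```python
import math


def focal_loss(p: float, y: int, gamma: float = 2.0) -> float:
    """Hypothetical helper: FL(p_t) = -(1 - p_t)^gamma * log(p_t) for one example.

    p is the predicted probability of the positive class (0 < p < 1),
    y is the ground-truth label (0 or 1).
    """
    if y not in (0, 1):
        raise ValueError("y must be 0 or 1")
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    # p_t is the probability the model assigns to the true class
    p_t = y * p + (1 - y) * (1 - p)
    # the modulating factor (1 - p_t)^gamma shrinks the loss of confident, correct predictions
    return -((1 - p_t) ** gamma) * math.log(p_t)


if __name__ == "__main__":
    for p in (0.9, 0.1):  # an easy and a hard positive example
        bce = focal_loss(p, 1, gamma=0.0)  # gamma = 0 recovers binary cross-entropy
        fl = focal_loss(p, 1, gamma=2.0)
        print(f"p={p}: BCE={bce:.4f}, FL={fl:.4f}, FL/BCE={fl / bce:.2f}")
```

With gamma = 2, the easy positive (p = 0.9) keeps only \((1 - 0.9)^2 = 1\%\) of its cross-entropy loss, while the hard positive (p = 0.1) keeps \((1 - 0.1)^2 = 81\%\).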
Focal loss has further desirable properties, e.g., it yields better-calibrated predicted probabilities; see [mukhoti2020].
Initialize the loss module.
- Parameters:
gamma (float) – Exponent of the modulating factor (1 - p_t), must be >= 0; it balances easy vs. hard examples. Setting gamma > 0 reduces the relative loss for well-classified examples. The default value of 2 is taken from [lin2018], who report this setting to work best in their experiments. However, those experiments were conducted on dense object detection in images, so the default should be taken with a grain of salt.
alpha (float | None) – Weighting factor in the range (0, 1) to balance positive vs. negative examples. alpha is the weight for the positive class, i.e., increasing it makes the loss focus more on this class. The weight for the negative class is 1 - alpha. [lin2018] recommend either setting this to the inverse class frequency or treating it as a hyperparameter. If None, no class weighting is applied.
kwargs – Additional keyword-based arguments passed to pykeen.losses.PointwiseLoss.
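When alpha is set, the parameter description above corresponds to the α-balanced form from [lin2018]. Written in the notation of the formula given above (this is our reading of the alpha description; torchvision.ops.sigmoid_focal_loss() applies the same per-example weight), each term is scaled by \(\alpha_t\):

\[FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t), \qquad \alpha_t = y \cdot \alpha + (1 - y) \cdot (1 - \alpha)\]

Without class weighting, \(\alpha_t = 1\) and this reduces to the unweighted formula.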
- Raises:
ValueError – If alpha is given but not in the range (0, 1).
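The documented range check and the resulting class weight can be sketched as follows. The helper names `check_alpha` and `alpha_weight` are hypothetical and only mirror the behaviour described in the parameter list; they are not part of PyKEEN's API. Treating alpha=None as "no weighting" (a weight of 1) is an assumption of this sketch.

```python
from typing import Optional


def check_alpha(alpha: Optional[float]) -> None:
    """Hypothetical helper: reject alpha outside the open interval (0, 1)."""
    # None is assumed to mean "no class weighting" and is always accepted
    if alpha is not None and not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha!r}")


def alpha_weight(y: int, alpha: Optional[float]) -> float:
    """Hypothetical helper: per-example weight alpha_t for a label y in {0, 1}."""
    check_alpha(alpha)
    if alpha is None:
        return 1.0
    # positives are weighted by alpha, negatives by 1 - alpha
    return y * alpha + (1 - y) * (1 - alpha)
```

Because the interval is open, both endpoints are rejected: alpha = 0 would silence positives entirely and alpha = 1 would silence negatives.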
Methods Summary
forward(prediction, labels) – Define the computation performed at every call.
Methods Documentation
- forward(prediction: Tensor, labels: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
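For intuition about what forward computes, here is a dependency-free sketch modelled on torchvision.ops.sigmoid_focal_loss(). It rests on assumptions not stated above: prediction holds raw scores (logits) that are turned into probabilities with a sigmoid, labels are 0/1, and the per-element losses are averaged (mean reduction). The function names `log_sigmoid` and `sigmoid_focal_loss_mean` are hypothetical; the real forward operates on Tensor objects.

```python
import math
from typing import Optional, Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) that never overflows."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_focal_loss_mean(
    logits: Sequence[float],
    labels: Sequence[int],
    gamma: float = 2.0,
    alpha: Optional[float] = None,
) -> float:
    """Hypothetical sketch: mean focal loss over raw scores and 0/1 labels."""
    if len(logits) != len(labels) or not logits:
        raise ValueError("logits and labels must be non-empty and of equal length")
    total = 0.0
    for x, y in zip(logits, labels):
        log_p = log_sigmoid(x)  # log P(positive)
        log_not_p = log_sigmoid(-x)  # log P(negative) = log(1 - sigmoid(x))
        # binary cross-entropy computed directly from the logit
        ce = -(y * log_p + (1 - y) * log_not_p)
        p = math.exp(log_p)
        p_t = p * y + (1 - p) * (1 - y)
        loss = ce * (1 - p_t) ** gamma
        if alpha is not None:
            # alpha weights positives, 1 - alpha weights negatives
            loss *= alpha * y + (1 - alpha) * (1 - y)
        total += loss
    return total / len(logits)


if __name__ == "__main__":
    print(sigmoid_focal_loss_mean([2.0, -1.0, 0.5], [1, 0, 1], gamma=2.0, alpha=0.25))
```

Working in log space through `log_sigmoid` keeps the cross-entropy finite for very large or very small scores, which is the same reason torchvision builds on binary_cross_entropy_with_logits instead of taking the logarithm of a sigmoid.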