"""Regularization in PyKEEN."""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from typing import Any, ClassVar
import torch
from class_resolver import ClassResolver, normalize_string
from torch import nn
from torch.nn import functional
from .typing import FloatTensor
from .utils import lp_norm, powersum_norm
__all__ = [
    # the abstract base class every regularizer derives from
    "Regularizer",
    # concrete regularizer implementations
    "LpRegularizer",
    "NoRegularizer",
    "CombinedRegularizer",
    "PowerSumRegularizer",
    "OrthogonalityRegularizer",
    "NormLimitRegularizer",
    # helpers
    "regularizer_resolver",
]
#: Default hyper-parameter search space for the regularization weight (a float in [0.01, 1.0], log scale)
DEFAULT_REGULARIZER_WEIGHT_HPO_RANGE = {
    "weight": {"type": float, "low": 0.01, "high": 1.0, "scale": "log"},
}

#: Suffix removed from class names when computing a regularizer's normalized name
_REGULARIZER_SUFFIX = "Regularizer"
class Regularizer(nn.Module, ABC):
    """A base class for all regularizers.

    A regularizer accumulates a penalty term through :meth:`update` and hands out the weighted
    term through :meth:`pop_regularization_term`, which also resets the accumulator.
    """

    #: The overall regularization weight
    weight: FloatTensor

    #: The current regularization term (a one-element tensor)
    regularization_term: FloatTensor

    #: Should the regularization only be applied once? This was used for ConvKB and defaults to False.
    apply_only_once: bool

    #: Has this regularizer been updated since last being reset?
    updated: bool

    #: The default strategy for optimizing the regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = DEFAULT_REGULARIZER_WEIGHT_HPO_RANGE

    def __init__(
        self,
        weight: float = 1.0,
        apply_only_once: bool = False,
        parameters: Iterable[nn.Parameter] | None = None,
    ):
        """Instantiate the regularizer.

        :param weight: The relative weight of the regularization
        :param apply_only_once: Should the regularization only be applied once after each reset? If True,
            further calls to :meth:`update` are ignored until :meth:`reset` is called.
        :param parameters: Specific parameters to track. If none are given, it is expected that your
            model automatically delegates to the :meth:`update` function.
        """
        super().__init__()
        self.tracked_parameters = list(parameters) if parameters else []
        # buffers (not parameters): they move with the module across devices but are not optimized
        self.register_buffer(name="weight", tensor=torch.as_tensor(weight))
        self.apply_only_once = apply_only_once
        self.register_buffer(name="regularization_term", tensor=torch.zeros(1, dtype=torch.float))
        self.updated = False
        self.reset()

    @classmethod
    def get_normalized_name(cls) -> str:
        """Get the normalized name of the regularizer class."""
        return normalize_string(cls.__name__, suffix=_REGULARIZER_SUFFIX)

    def add_parameter(self, parameter: nn.Parameter) -> None:
        """Add a parameter for regularization."""
        self.tracked_parameters.append(parameter)

    def reset(self) -> None:
        """Reset the regularization term to zero."""
        # detach (in place) from any autograd graph built by previous updates, then zero in place
        self.regularization_term.detach_().zero_()
        self.updated = False

    @abstractmethod
    def forward(self, x: FloatTensor) -> FloatTensor:
        """Compute the regularization term for one tensor."""
        raise NotImplementedError

    def update(self, *tensors: FloatTensor) -> None:
        """Update the regularization term based on passed tensors.

        The update is skipped when no tensors are passed, outside of training mode, when gradient
        computation is disabled, or when :attr:`apply_only_once` is set and the term has already been
        updated since the last reset.

        :param tensors: the tensors to regularize; each is passed to :meth:`forward` individually
        """
        if not tensors:
            # nothing to add; in particular, do not mark the term as updated, since with
            # apply_only_once this would silently block the next real update
            return
        if not self.training or not torch.is_grad_enabled() or (self.apply_only_once and self.updated):
            return
        self.regularization_term = self.regularization_term + sum(self.forward(x=x) for x in tensors)
        self.updated = True

    @property
    def term(self) -> FloatTensor:
        """Return the weighted regularization term."""
        return self.regularization_term * self.weight

    def pop_regularization_term(self) -> FloatTensor:
        """Return the weighted regularization term, and reset the regularizer afterwards."""
        # If there are tracked parameters, update based on them
        if self.tracked_parameters:
            self.update(*self.tracked_parameters)
        result = self.weight * self.regularization_term
        self.reset()
        return result

    def post_parameter_update(self):
        """
        Reset the regularizer's term.

        .. warning ::

            Typically, you want to use the regularization term exactly once to calculate gradients via
            :meth:`pop_regularization_term`. In this case, there should be no need to manually call this method.
        """
        if self.updated:
            warnings.warn("Resetting regularization term without using it; this may be an error.", stacklevel=2)
        self.reset()
class NoRegularizer(Regularizer):
    """A regularizer that never adds any penalty.

    It exists so that code can always assume some regularizer object is present.
    """

    #: The no-op regularizer has no hyper-parameters worth optimizing
    hpo_default: ClassVar[Mapping[str, Any]] = {}

    # docstr-coverage: inherited
    def update(self, *tensors: FloatTensor) -> None:  # noqa: D102
        # deliberately ignore all inputs: there is nothing to accumulate
        return None

    # docstr-coverage: inherited
    def forward(self, x: FloatTensor) -> FloatTensor:  # noqa: D102
        # a single zero living on the same device and in the same dtype as the input
        return torch.zeros(1, device=x.device, dtype=x.dtype)
class LpRegularizer(Regularizer):
    """A simple L_p norm based regularizer."""

    #: The dimension along which to compute the vector-based regularization terms.
    dim: int | None

    #: Whether to normalize the regularization term by the dimension of the vectors.
    #: This allows dimensionality-independent weight tuning.
    normalize: bool

    #: The parameter $p$ of the L_p norm.
    p: float

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = False,
        dim: int | None = -1,
        normalize: bool = False,
        p: float = 2.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Should the regularization only be applied once after each reset? If so, further updates are
            ignored until the term is reset.
        :param dim:
            the dimension along which to calculate the Lp norm, cf. :func:`lp_norm`
        :param normalize:
            whether to normalize the norm by the dimension, cf. :func:`lp_norm`
        :param p:
            the parameter $p$ of the Lp norm, cf. :func:`lp_norm`
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        self.dim = dim
        self.normalize = normalize
        self.p = p

    # docstr-coverage: inherited
    def forward(self, x: FloatTensor) -> FloatTensor:  # noqa: D102
        # average the output of lp_norm (presumably one norm per vector along `dim`) into a scalar
        return lp_norm(x=x, p=self.p, dim=self.dim, normalize=self.normalize).mean()
class PowerSumRegularizer(Regularizer):
    """A simple x^p based regularizer.

    Has some nice properties, cf. e.g. https://github.com/pytorch/pytorch/issues/28119.
    """

    #: The dimension along which to compute the power sum.
    dim: int | None

    #: Whether to normalize the power sum by the dimension, cf. :func:`powersum_norm`.
    normalize: bool

    #: The exponent $p$.
    p: float

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = False,
        dim: int | None = -1,
        normalize: bool = False,
        p: float = 2.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Should the regularization only be applied once after each reset? If so, further updates are
            ignored until the term is reset.
        :param dim:
            the dimension along which to calculate the power sum, cf. :func:`powersum_norm`
        :param normalize:
            whether to normalize the power sum by the dimension, cf. :func:`powersum_norm`
        :param p:
            the exponent $p$, cf. :func:`powersum_norm`
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        self.dim = dim
        self.normalize = normalize
        self.p = p

    # docstr-coverage: inherited
    def forward(self, x: FloatTensor) -> FloatTensor:  # noqa: D102
        # average the output of powersum_norm (presumably one value per vector along `dim`) into a scalar
        return powersum_norm(x, p=self.p, dim=self.dim, normalize=self.normalize).mean()
class NormLimitRegularizer(Regularizer):
    """A regularizer which formulates a soft constraint on a maximum norm."""

    #: The dimension along which to compute the norm.
    dim: int | None

    #: The parameter $p$ of the norm.
    p: float

    #: The norm threshold up to which no penalty is added.
    max_norm: float

    #: Whether to use the $p$-th power of the norm (cf. :func:`powersum_norm`) instead of the norm itself.
    power_norm: bool

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = False,
        # regularizer-specific parameters
        dim: int | None = -1,
        p: float = 2.0,
        power_norm: bool = True,
        max_norm: float = 1.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Should the regularization only be applied once after each reset? If so, further updates are
            ignored until the term is reset.
        :param dim:
            the dimension along which to calculate the Lp norm, cf. :func:`powersum_norm` / :func:`lp_norm`
        :param p:
            the parameter $p$ of the Lp norm, cf. :func:`powersum_norm` / :func:`lp_norm`
        :param power_norm:
            whether to use the $p$ power of the norm instead
        :param max_norm:
            the maximum norm until which no penalty is added
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        self.dim = dim
        self.p = p
        self.max_norm = max_norm
        self.power_norm = power_norm

    # docstr-coverage: inherited
    def forward(self, x: FloatTensor) -> FloatTensor:  # noqa: D102
        if self.power_norm:
            norm = powersum_norm(x, p=self.p, dim=self.dim, normalize=False)
        else:
            norm = lp_norm(x=x, p=self.p, dim=self.dim, normalize=False)
        # hinge penalty: only the amount by which a norm exceeds max_norm contributes
        return (norm - self.max_norm).relu().sum()
class OrthogonalityRegularizer(Regularizer):
    """A regularizer for the soft orthogonality constraints from [wang2014]_."""

    #: Tolerance on the squared cosine similarity below which no penalty is added.
    epsilon: float

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = True,
        epsilon: float = 1e-5,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Should the regularization only be applied once after each reset? If so, further updates are
            ignored until the term is reset.
        :param epsilon:
            a small value used to check for approximate orthogonality
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        self.epsilon = epsilon

    # docstr-coverage: inherited
    def forward(self, x: FloatTensor) -> FloatTensor:  # noqa: D102
        # the penalty depends on a *pair* of tensors, so it cannot be computed per tensor
        raise NotImplementedError(f"{self.__class__.__name__} regularizer is order-sensitive!")

    def update(self, *tensors: FloatTensor) -> None:
        """Update the regularization term with the orthogonality penalty of a pair of tensors.

        Like the base class, the update is skipped outside of training mode, when gradient computation
        is disabled, or when :attr:`apply_only_once` is set and the term was already updated.

        :param tensors:
            exactly two tensors; vectors along their last dimension are compared pairwise
        :raises ValueError:
            if not exactly two tensors are passed
        """
        if len(tensors) != 2:
            raise ValueError("Expects exactly two tensors")
        # Mirror the guard of Regularizer.update: otherwise, e.g., evaluation under torch.no_grad()
        # accumulates a gradient-free term and (with apply_only_once) marks the regularizer as updated,
        # which makes the next training step skip its real update.
        if not self.training or not torch.is_grad_enabled() or (self.apply_only_once and self.updated):
            return
        # orthogonality soft constraint: squared cosine similarity at most epsilon
        self.regularization_term = self.regularization_term + (
            functional.cosine_similarity(*tensors, dim=-1).pow(2).subtract(self.epsilon).relu().sum()
        )
        self.updated = True
class CombinedRegularizer(Regularizer):
    """A convex combination of regularizers."""

    #: The normalization factor to balance individual regularizers' contribution.
    normalization_factor: FloatTensor

    #: The default strategy for optimizing the combined regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        total_weight=dict(type=float, low=0.01, high=1.0, scale="log"),
        regularizers=tuple(),
    )

    def __init__(
        self,
        regularizers: Iterable[Regularizer],
        total_weight: float = 1.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param regularizers:
            the base regularizers
        :param total_weight:
            the total regularization weight distributed to the base regularizers according to their individual weights
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        :raises TypeError:
            if any of the regularizers are a no-op regularizer
        """
        super().__init__(weight=total_weight, **kwargs)
        self.regularizers = nn.ModuleList(regularizers)
        for r in self.regularizers:
            if isinstance(r, NoRegularizer):
                raise TypeError("Can not combine a no-op regularizer")
        # reciprocal of the summed base weights, so that forward() yields a weighted average of the base terms
        self.register_buffer(
            name="normalization_factor",
            tensor=torch.as_tensor(
                sum(r.weight for r in self.regularizers),
            ).reciprocal(),
        )

    @property
    def normalize(self) -> bool:
        """Return whether any of the base regularizers normalizes its term.

        Base regularizers without a ``normalize`` attribute (e.g. :class:`NormLimitRegularizer`)
        count as not normalizing instead of raising an :class:`AttributeError`.
        """
        return any(getattr(r, "normalize", False) for r in self.regularizers)

    # docstr-coverage: inherited
    def forward(self, x: FloatTensor) -> FloatTensor:  # noqa: D102
        return self.normalization_factor * sum(r.weight * r.forward(x) for r in self.regularizers)
#: Resolves regularizer classes (and their normalized names) to instances; falls back to the no-op regularizer
regularizer_resolver: ClassResolver[Regularizer] = ClassResolver.from_subclasses(
    location="pykeen.regularizers.regularizer_resolver",
    default=NoRegularizer,
    base=Regularizer,
)