# -*- coding: utf-8 -*-
"""Regularization in PyKEEN."""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Iterable, Mapping, Optional
import torch
from class_resolver import ClassResolver, normalize_string
from torch import nn
from torch.nn import functional
from .utils import lp_norm, powersum_norm
# Public names of this module.
__all__ = [
# Base Class
"Regularizer",
# Child classes
"LpRegularizer",
"NoRegularizer",
"CombinedRegularizer",
"PowerSumRegularizer",
"OrthogonalityRegularizer",
"NormLimitRegularizer",
# Utils
"regularizer_resolver",
]
# Suffix handed to ``normalize_string`` when computing a regularizer class' normalized name.
_REGULARIZER_SUFFIX = "Regularizer"
class Regularizer(nn.Module, ABC):
    """An abstract base class for regularizers which accumulate a penalty term.

    Subclasses implement :meth:`forward`, which maps a single tensor to its unweighted penalty.
    Penalties are accumulated by :meth:`update` and handed out, scaled by the regularization
    weight, by :meth:`pop_regularization_term`.
    """

    #: The overall regularization weight
    weight: torch.FloatTensor

    #: The accumulated (not yet weighted) regularization term, a single-element tensor
    regularization_term: torch.FloatTensor

    #: Should the term be updated at most once between two resets? This was used for ConvKB and defaults to False.
    apply_only_once: bool

    #: Has this regularizer been updated since last being reset?
    updated: bool

    #: The default strategy for optimizing the regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]]

    def __init__(
        self,
        weight: float = 1.0,
        apply_only_once: bool = False,
        parameters: Optional[Iterable[nn.Parameter]] = None,
    ):
        """Instantiate the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            If true, only the first call to :meth:`update` after a reset changes the term;
            later calls are ignored until the next reset.
        :param parameters:
            Specific parameters to track. If none are given, it is expected that your model
            delegates to :meth:`update` itself.
        """
        super().__init__()
        self.tracked_parameters = list(parameters or [])
        self.apply_only_once = apply_only_once
        # buffers (rather than plain attributes), so they are part of the module's state
        self.register_buffer("weight", torch.as_tensor(weight))
        self.register_buffer("regularization_term", torch.zeros(1, dtype=torch.float))
        self.updated = False
        self.reset()

    @classmethod
    def get_normalized_name(cls) -> str:
        """Get the normalized name of the regularizer class."""
        class_name = cls.__name__
        return normalize_string(class_name, suffix=_REGULARIZER_SUFFIX)

    def add_parameter(self, parameter: nn.Parameter) -> None:
        """Add a parameter to the list of tracked parameters."""
        self.tracked_parameters.append(parameter)

    def reset(self) -> None:
        """Detach the regularization term, set it to zero, and clear the ``updated`` flag."""
        accumulated = self.regularization_term
        accumulated.detach_()
        accumulated.zero_()
        self.updated = False

    @abstractmethod
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Compute the (unweighted) regularization term for one tensor."""
        raise NotImplementedError

    def update(self, *tensors: torch.FloatTensor) -> None:
        """Add the penalties of the given tensors to the regularization term.

        The call is a no-op when the module is not in training mode, when gradient
        computation is disabled, or when ``apply_only_once`` is set and the term has
        already been updated since the last reset.

        :param tensors:
            the tensors to regularize; each one is passed to :meth:`forward`
        """
        if not self.training:
            return
        if not torch.is_grad_enabled():
            return
        if self.apply_only_once and self.updated:
            return
        penalty = sum(self.forward(x=tensor) for tensor in tensors)
        self.regularization_term = self.regularization_term + penalty
        self.updated = True

    @property
    def term(self) -> torch.FloatTensor:
        """Return the weighted regularization term."""
        unweighted = self.regularization_term
        return unweighted * self.weight

    def pop_regularization_term(self) -> torch.FloatTensor:
        """Return the weighted regularization term, and reset the regularizer afterwards."""
        # tracked parameters contribute right before the term is read
        if self.tracked_parameters:
            self.update(*self.tracked_parameters)
        weighted = self.weight * self.regularization_term
        self.reset()
        return weighted

    def post_parameter_update(self):
        """
        Reset the regularizer's term.

        .. warning ::
            Typically, you want to use the regularization term exactly once to calculate gradients via
            :meth:`pop_regularization_term`. In this case, there should be no need to manually call this method.
        """
        # an update which has not been reset yet is about to be discarded
        if self.updated:
            warnings.warn("Resetting regularization term without using it; this may be an error.")
        self.reset()
class NoRegularizer(Regularizer):
    """A no-op regularizer which never contributes a penalty.

    Used to simplify code, since callers do not have to special-case missing regularization.
    """

    #: The default strategy for optimizing the no-op regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {}

    # docstr-coverage: inherited
    def update(self, *tensors: torch.FloatTensor) -> None:  # noqa: D102
        # nothing is accumulated, hence the tensors are ignored
        return None

    # docstr-coverage: inherited
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        # a single zero with the input's dtype, on the input's device
        return torch.zeros(1, device=x.device, dtype=x.dtype)
class LpRegularizer(Regularizer):
    """A simple L_p norm based regularizer."""

    #: The dimension along which to compute the vector-based regularization terms.
    dim: Optional[int]

    #: Whether to normalize the regularization term by the dimension of the vectors.
    #: This allows dimensionality-independent weight tuning.
    normalize: bool

    #: The default strategy for optimizing the LP regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        "weight": {"type": float, "low": 0.01, "high": 1.0, "scale": "log"},
    }

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = False,
        dim: Optional[int] = -1,
        normalize: bool = False,
        p: float = 2.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Whether the regularization term is updated at most once between two resets
        :param dim:
            the dimension along which to calculate the Lp norm, cf. :func:`lp_norm`
        :param normalize:
            whether to normalize the norm by the dimension, cf. :func:`lp_norm`
        :param p:
            the parameter $p$ of the Lp norm, cf. :func:`lp_norm`
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        # settings forwarded to lp_norm by forward()
        self.p = p
        self.dim = dim
        self.normalize = normalize

    # docstr-coverage: inherited
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        norms = lp_norm(x=x, p=self.p, dim=self.dim, normalize=self.normalize)
        # average the norms into a single value
        return norms.mean()
class PowerSumRegularizer(Regularizer):
    """A simple x^p based regularizer.

    Has some nice properties, cf. e.g. https://github.com/pytorch/pytorch/issues/28119.
    """

    #: The default strategy for optimizing the power sum regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        "weight": {"type": float, "low": 0.01, "high": 1.0, "scale": "log"},
    }

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = False,
        dim: Optional[int] = -1,
        normalize: bool = False,
        p: float = 2.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Whether the regularization term is updated at most once between two resets
        :param dim:
            the dimension along which to calculate the power sum, cf. :func:`powersum_norm`
        :param normalize:
            whether to normalize by the dimension, cf. :func:`powersum_norm`
        :param p:
            the exponent $p$, cf. :func:`powersum_norm`
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        # settings forwarded to powersum_norm by forward()
        self.p = p
        self.dim = dim
        self.normalize = normalize

    # docstr-coverage: inherited
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        power_sums = powersum_norm(x, p=self.p, dim=self.dim, normalize=self.normalize)
        # average the power sums into a single value
        return power_sums.mean()
class NormLimitRegularizer(Regularizer):
    """A regularizer which formulates a soft constraint on a maximum norm."""

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = False,
        # regularizer-specific parameters
        dim: Optional[int] = -1,
        p: float = 2.0,
        power_norm: bool = True,
        max_norm: float = 1.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Whether the regularization term is updated at most once between two resets
        :param dim:
            the dimension along which to calculate the norm, cf. :func:`powersum_norm` and :func:`lp_norm`
        :param p:
            the parameter $p$ of the norm, cf. :func:`powersum_norm` and :func:`lp_norm`
        :param power_norm:
            whether to use :func:`powersum_norm` (if true) rather than :func:`lp_norm`
        :param max_norm:
            the maximum norm until which no penalty is added
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, apply_only_once=apply_only_once, **kwargs)
        self.power_norm = power_norm
        self.max_norm = max_norm
        self.p = p
        self.dim = dim

    # docstr-coverage: inherited
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        if self.power_norm:
            norm = powersum_norm(x, p=self.p, dim=self.dim, normalize=False)
        else:
            norm = lp_norm(x=x, p=self.p, dim=self.dim, normalize=False)
        # only the part of the norm which exceeds the limit is penalized
        excess = norm - self.max_norm
        return excess.relu().sum()
class OrthogonalityRegularizer(Regularizer):
    """A regularizer for the soft orthogonality constraints from [wang2014]_.

    In contrast to the other regularizers, the penalty depends on a *pair* of tensors. Hence,
    :meth:`forward` is not supported, and :meth:`update` has to be called with exactly two tensors.
    """

    #: The default strategy for optimizing the TransH regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        weight=dict(type=float, low=0.01, high=1.0, scale="log"),
    )

    def __init__(
        self,
        *,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        weight: float = 1.0,
        # could be moved into kwargs, but needs to stay for experiment integrity check
        apply_only_once: bool = True,
        epsilon: float = 1e-5,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param weight:
            The relative weight of the regularization
        :param apply_only_once:
            Whether the regularization term is updated at most once between two resets
        :param epsilon:
            a small value used to check for approximate orthogonality; squared cosine
            similarities up to this value are not penalized
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        """
        super().__init__(weight=weight, **kwargs, apply_only_once=apply_only_once)
        self.epsilon = epsilon

    # docstr-coverage: inherited
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        raise NotImplementedError(f"{self.__class__.__name__} regularizer is order-sensitive!")

    # docstr-coverage: inherited
    def update(self, *tensors: torch.FloatTensor) -> None:  # noqa: D102
        # wrong usage is reported regardless of the module's mode
        if len(tensors) != 2:
            raise ValueError("Expects exactly two tensors")
        # Use the same guards as Regularizer.update. Without the training / grad-mode checks, a call
        # during evaluation would accumulate a gradient-free term and set ``updated``, which (with
        # ``apply_only_once``) suppresses the next update made during training.
        if not self.training or not torch.is_grad_enabled() or (self.apply_only_once and self.updated):
            return
        # orthogonality soft constraint: squared cosine similarity at most epsilon
        self.regularization_term = self.regularization_term + (
            functional.cosine_similarity(*tensors, dim=-1).pow(2).subtract(self.epsilon).relu().sum()
        )
        self.updated = True
class CombinedRegularizer(Regularizer):
    """A convex combination of regularizers."""

    #: The normalization factor to balance individual regularizers' contribution,
    #: i.e., the reciprocal of the sum of the base regularizers' weights.
    normalization_factor: torch.FloatTensor

    def __init__(
        self,
        regularizers: Iterable[Regularizer],
        total_weight: float = 1.0,
        **kwargs,
    ):
        """
        Initialize the regularizer.

        :param regularizers:
            the base regularizers
        :param total_weight:
            the total regularization weight distributed to the base regularizers according to their individual weights
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Regularizer.__init__`
        :raises TypeError:
            if any of the regularizers are a no-op regularizer
        """
        super().__init__(weight=total_weight, **kwargs)
        # materialize the iterable as sub-modules first, so it can be iterated more than once
        self.regularizers = nn.ModuleList(regularizers)
        for r in self.regularizers:
            if isinstance(r, NoRegularizer):
                raise TypeError("Can not combine a no-op regularizer")
        self.register_buffer(
            name="normalization_factor",
            tensor=torch.as_tensor(
                sum(r.weight for r in self.regularizers),
            ).reciprocal(),
        )

    # docstr-coverage: inherited
    @property
    def normalize(self) -> bool:  # noqa: D102
        # Not every regularizer defines ``normalize`` (e.g., the norm-limit and orthogonality
        # regularizers do not); those are treated as non-normalizing instead of raising an
        # AttributeError from within this property.
        return any(getattr(r, "normalize", False) for r in self.regularizers)

    # docstr-coverage: inherited
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        # weighted sum of the base regularizers' terms, rescaled by the reciprocal of the weight sum
        return self.normalization_factor * sum(r.weight * r.forward(x) for r in self.regularizers)
# Resolver built from the subclasses of :class:`Regularizer`, with :class:`NoRegularizer` as default.
# NOTE(review): presumably used to look up regularizer classes by (normalized) name - confirm against callers.
regularizer_resolver: ClassResolver[Regularizer] = ClassResolver.from_subclasses(
base=Regularizer,
default=NoRegularizer,
)