Source code for pykeen.regularizers

# -*- coding: utf-8 -*-

"""Regularization in PyKEEN."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, ClassVar, Collection, Iterable, Mapping, Optional, Type, Union

import torch
from torch import nn
from torch.nn import functional

from .nn.norm import lp_norm, powersum_norm
from .utils import get_cls, normalize_string

# Names exported as the public API of this module.
__all__ = [
    'Regularizer',
    'LpRegularizer',
    'NoRegularizer',
    'CombinedRegularizer',
    'PowerSumRegularizer',
    'TransHRegularizer',
    'get_regularizer_cls',
]

# Class-name suffix handed to ``normalize_string`` and ``get_cls`` when regularizer
# classes are looked up by name.
_REGULARIZER_SUFFIX = 'Regularizer'


class Regularizer(nn.Module, ABC):
    """A base class for all regularizers.

    A regularizer accumulates an unweighted regularization term over one or more
    calls to :meth:`update` and hands out the weighted term (and clears its state)
    via :meth:`pop_regularization_term`. Subclasses implement :meth:`forward`,
    which computes the contribution of a single tensor.
    """

    #: The overall regularization weight
    weight: torch.FloatTensor

    #: The current regularization term (a scalar)
    regularization_term: torch.FloatTensor

    #: Should the regularization only be applied once? This was used for ConvKB and defaults to False.
    apply_only_once: bool

    #: Has this regularizer been updated since last being reset?
    updated: bool

    #: The default strategy for optimizing the regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]]

    def __init__(
        self,
        weight: float = 1.0,
        apply_only_once: bool = False,
        parameters: Optional[Iterable[nn.Parameter]] = None,
    ):
        """Instantiate the regularizer.

        :param weight: The relative weight of the regularization
        :param apply_only_once: Should the regularization be applied only once after each reset?
            If True, calls to :func:`update` after the first one are ignored until :func:`reset`.
        :param parameters: Specific parameters to track. if none given, it's expected that your
            model automatically delegates to the :func:`update` function.
        """
        super().__init__()
        self.tracked_parameters = list(parameters) if parameters is not None else []
        self.register_buffer(name='weight', tensor=torch.as_tensor(weight))
        self.apply_only_once = apply_only_once
        self.register_buffer(name="regularization_term", tensor=torch.zeros(1, dtype=torch.float))
        self.updated = False
        self.reset()

    @classmethod
    def get_normalized_name(cls) -> str:
        """Get the normalized name of the regularizer class."""
        return normalize_string(cls.__name__, suffix=_REGULARIZER_SUFFIX)

    def add_parameter(self, parameter: nn.Parameter) -> None:
        """Add a parameter for regularization."""
        self.tracked_parameters.append(parameter)

    def reset(self) -> None:
        """Reset the regularization term to zero and clear the ``updated`` flag."""
        current = self.regularization_term
        # Rebind the buffer to a *fresh* zero tensor instead of zeroing in place.
        # An in-place ``detach_().zero_()`` would also wipe (and cut from the autograd
        # graph) any reference to the accumulated term that a caller still holds,
        # e.g. the one taken in ``pop_regularization_term``.
        self.regularization_term = torch.zeros(1, dtype=current.dtype, device=current.device)
        self.updated = False

    @abstractmethod
    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Compute the regularization term for one tensor."""
        raise NotImplementedError

    def update(self, *tensors: torch.FloatTensor) -> None:
        """Update the regularization term based on passed tensors.

        Nothing happens if the module is not in training mode, if gradient
        computation is disabled, or if ``apply_only_once`` is set and an update
        has already taken place since the last reset.

        :param tensors: The tensors whose :meth:`forward` values are added to the term.
        """
        if not self.training or not torch.is_grad_enabled() or (self.apply_only_once and self.updated):
            return
        self.regularization_term = self.regularization_term + sum(self.forward(x=x) for x in tensors)
        self.updated = True

    @property
    def term(self) -> torch.FloatTensor:
        """Return the weighted regularization term."""
        return self.regularization_term * self.weight

    def pop_regularization_term(self) -> torch.FloatTensor:
        """Return the weighted regularization term, and reset the regularizer afterwards."""
        # If there are tracked parameters, update based on them
        if self.tracked_parameters:
            self.update(*self.tracked_parameters)

        term = self.regularization_term
        # ``reset`` rebinds the buffer, so ``term`` keeps the accumulated value
        # together with its autograd history.
        self.reset()
        return self.weight * term


class NoRegularizer(Regularizer):
    """A regularizer which does not perform any regularization.

    It exists so that calling code does not need a special case for
    "no regularization".
    """

    #: The default strategy for optimizing the no-op regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {}

    def update(self, *tensors: torch.FloatTensor) -> None:
        """Ignore the given tensors; nothing is ever accumulated."""
        return None

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return a single zero with the dtype and device of ``x``."""
        zero = torch.zeros(1, dtype=x.dtype, device=x.device)
        return zero


class LpRegularizer(Regularizer):
    """A simple L_p norm based regularizer."""

    #: The dimension along which to compute the vector-based regularization terms.
    dim: Optional[int]

    #: Whether to normalize the regularization term by the dimension of the vectors.
    #: This allows dimensionality-independent weight tuning.
    normalize: bool

    #: The default strategy for optimizing the LP regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        'weight': {'type': float, 'low': 0.01, 'high': 1.0, 'scale': 'log'},
    }

    def __init__(
        self,
        weight: float = 1.0,
        dim: Optional[int] = -1,
        normalize: bool = False,
        p: float = 2.,
        apply_only_once: bool = False,
        parameters: Optional[Iterable[nn.Parameter]] = None,
    ):
        """Instantiate the regularizer.

        :param weight: The relative weight of the regularization
        :param dim: The dimension forwarded to ``lp_norm``
        :param normalize: The normalization flag forwarded to ``lp_norm``
        :param p: The norm order forwarded to ``lp_norm``
        :param apply_only_once: Forwarded to the base class
        :param parameters: Specific parameters to track; forwarded to the base class
        """
        super().__init__(
            weight=weight,
            apply_only_once=apply_only_once,
            parameters=parameters,
        )
        self.p = p
        self.normalize = normalize
        self.dim = dim

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return the mean of ``lp_norm`` evaluated on ``x`` with the configured settings."""
        norms = lp_norm(x=x, p=self.p, dim=self.dim, normalize=self.normalize)
        return norms.mean()


class PowerSumRegularizer(Regularizer):
    """A simple x^p based regularizer.

    Has some nice properties, cf. e.g. https://github.com/pytorch/pytorch/issues/28119.
    """

    #: The default strategy for optimizing the power sum regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        'weight': {'type': float, 'low': 0.01, 'high': 1.0, 'scale': 'log'},
    }

    def __init__(
        self,
        weight: float = 1.0,
        dim: Optional[int] = -1,
        normalize: bool = False,
        p: float = 2.,
        apply_only_once: bool = False,
        parameters: Optional[Iterable[nn.Parameter]] = None,
    ):
        """Instantiate the regularizer.

        :param weight: The relative weight of the regularization
        :param dim: The dimension forwarded to ``powersum_norm``
        :param normalize: The normalization flag forwarded to ``powersum_norm``
        :param p: The exponent forwarded to ``powersum_norm``
        :param apply_only_once: Forwarded to the base class
        :param parameters: Specific parameters to track; forwarded to the base class
        """
        super().__init__(
            weight=weight,
            apply_only_once=apply_only_once,
            parameters=parameters,
        )
        self.p = p
        self.normalize = normalize
        self.dim = dim

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return the mean of ``powersum_norm`` evaluated on ``x`` with the configured settings."""
        power_sums = powersum_norm(x, p=self.p, dim=self.dim, normalize=self.normalize)
        return power_sums.mean()


class TransHRegularizer(Regularizer):
    """A regularizer for the soft constraints in TransH."""

    #: The default strategy for optimizing the TransH regularizer's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        'weight': {'type': float, 'low': 0.01, 'high': 1.0, 'scale': 'log'},
    }

    def __init__(
        self,
        weight: float = 0.05,
        epsilon: float = 1e-5,
        parameters: Optional[Iterable[nn.Parameter]] = None,
    ):
        """Instantiate the regularizer.

        :param weight: The relative weight of the regularization
        :param epsilon: The slack subtracted in the orthogonality soft constraint
        :param parameters: Specific parameters to track; forwarded to the base class
        """
        # The TransH soft constraints are meant to be evaluated a single time per
        # batch, hence apply_only_once is hard-wired to True.
        super().__init__(weight=weight, apply_only_once=True, parameters=parameters)
        self.epsilon = epsilon

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Always raise: the constraints need three tensors in a fixed order, see :meth:`update`."""
        raise NotImplementedError('TransH regularizer is order-sensitive!')

    def update(self, *tensors: torch.FloatTensor) -> None:
        """Add both TransH soft-constraint penalties to the regularization term.

        :param tensors: Exactly three tensors, in the order entity embeddings,
            normal vector embeddings, relation embeddings.
        :raises KeyError: If the number of tensors is not three.
        """
        if len(tensors) != 3:
            raise KeyError('Expects exactly three tensors')
        if self.apply_only_once and self.updated:
            return

        entities, normal_vectors, relations = tensors

        # Entity soft constraint: penalize squared norms above one
        squared_entity_norms = torch.norm(entities, dim=-1) ** 2
        self.regularization_term += functional.relu(squared_entity_norms - 1.0).sum()

        # Orthogonality soft constraint between normal vectors and normalized relations
        unit_relations = functional.normalize(relations, dim=-1)
        squared_overlap = ((normal_vectors * unit_relations) ** 2).sum(dim=-1)
        self.regularization_term += functional.relu(squared_overlap - self.epsilon).sum()

        self.updated = True


class CombinedRegularizer(Regularizer):
    """A convex combination of regularizers."""

    # The normalization factor to balance individual regularizers' contribution.
    normalization_factor: torch.FloatTensor

    def __init__(
        self,
        regularizers: Iterable[Regularizer],
        total_weight: float = 1.0,
        apply_only_once: bool = False,
    ):
        """Instantiate the combined regularizer.

        :param regularizers: The regularizers to combine; none of them may be a
            :class:`NoRegularizer`.
        :param total_weight: The weight applied to the combined term
        :param apply_only_once: Forwarded to the base class
        :raises TypeError: If any of the given regularizers is a :class:`NoRegularizer`.
        """
        super().__init__(weight=total_weight, apply_only_once=apply_only_once)
        self.regularizers = nn.ModuleList(regularizers)
        for r in self.regularizers:
            if isinstance(r, NoRegularizer):
                raise TypeError('Can not combine a no-op regularizer')
        # Reciprocal of the summed child weights, so that the effective weights
        # used in ``forward`` sum to one.
        self.register_buffer(name='normalization_factor', tensor=torch.as_tensor(
            sum(r.weight for r in self.regularizers),
        ).reciprocal())

    @property
    def normalize(self):
        """Return whether any of the combined regularizers has ``normalize`` set."""
        # Not every regularizer defines ``normalize`` (e.g. TransHRegularizer does not);
        # treat a missing attribute as False instead of raising AttributeError.
        return any(getattr(r, 'normalize', False) for r in self.regularizers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return the weight-normalized sum of the child regularizers' terms for ``x``."""
        return self.normalization_factor * sum(r.weight * r.forward(x) for r in self.regularizers)


# All regularizer classes that can be looked up by name
_REGULARIZERS: Collection[Type[Regularizer]] = {
    NoRegularizer,  # type: ignore
    LpRegularizer,
    PowerSumRegularizer,
    CombinedRegularizer,
    TransHRegularizer,
}

#: A mapping of regularizers' names to their implementations
regularizers: Mapping[str, Type[Regularizer]] = {
    regularizer_cls.get_normalized_name(): regularizer_cls
    for regularizer_cls in _REGULARIZERS
}


def get_regularizer_cls(query: Union[None, str, Type[Regularizer]]) -> Type[Regularizer]:
    """Get the regularizer class.

    The lookup is delegated to ``get_cls`` using the module-level ``regularizers``
    mapping, with :class:`NoRegularizer` passed as the default.

    :param query: ``None``, a name, or a regularizer class
    :return: The regularizer class resolved by ``get_cls``
    """
    resolved = get_cls(
        query,
        base=Regularizer,  # type: ignore
        lookup_dict=regularizers,
        default=NoRegularizer,
        suffix=_REGULARIZER_SUFFIX,
    )
    return resolved