Source code for pykeen.nn.vision.representation

"""
A module of vision related components.

Generally requires :mod:`torchvision` to be installed.
"""

import functools
import pathlib
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
import torch.nn
import torch.utils.data
from class_resolver import OptionalKwargs
from docdata import parse_docdata
from typing_extensions import Self

from .cache import WikidataImageCache
from ..representation import BackfillRepresentation, Representation
from ..utils import ShapeError
from ...datasets import Dataset
from ...triples import TriplesFactory
from ...typing import FloatTensor, LongTensor, OneOrSequence

try:
    from PIL import Image
    from torchvision import models
    from torchvision import transforms as vision_transforms
except ImportError:
    models = vision_transforms = Image = None

__all__ = [
    "VisionDataset",
    "VisualRepresentation",
    "WikidataVisualRepresentation",
    "ImageHint",
    "ImageHints",
]


def _ensure_vision(instance: object, module: Optional[Any]):
    if module is None:
        raise ImportError(f"{instance.__class__.__name__} requires `torchvision` to be installed.")


#: A path to an image file or a tensor representation of the image
ImageHint = Union[str, pathlib.Path, torch.Tensor]
#: A sequence of image hints
ImageHints = Sequence[ImageHint]


class VisionDataset(torch.utils.data.Dataset):
    """
    A dataset of images with data augmentation.

    .. note ::
        Requires :mod:`torchvision` to be installed.
    """

    def __init__(
        self,
        images: ImageHints,
        transforms: Optional[Sequence] = None,
        root: Optional[pathlib.Path] = None,
    ) -> None:
        """
        Initialize the dataset.

        :param images:
            The images, either as (relative) paths, or preprocessed tensors.
        :param transforms:
            A sequence of transformations to apply to the images, cf. :mod:`torchvision.transforms`.
            Defaults to random size crops.
        :param root:
            The root directory for images.
        """
        _ensure_vision(self, vision_transforms)
        super().__init__()
        if root is None:
            root = pathlib.Path.cwd()
        self.root = pathlib.Path(root)
        self.images = images
        if transforms is None:
            transforms = [vision_transforms.RandomResizedCrop(size=224), vision_transforms.ToTensor()]
        transforms = list(transforms)
        transforms.append(vision_transforms.ConvertImageDtype(torch.get_default_dtype()))
        self.transforms = vision_transforms.Compose(transforms=transforms)

    # docstr-coverage: inherited
    def __getitem__(self, item: int) -> torch.Tensor:  # noqa:D105
        _ensure_vision(self, Image)
        image = self.images[item]
        if isinstance(image, (str, pathlib.Path)):
            path = pathlib.Path(image)
            if not path.is_absolute():
                path = self.root.joinpath(path)
            image = Image.open(path)
        assert isinstance(image, (torch.Tensor, Image.Image))
        return self.transforms(image)

    # docstr-coverage: inherited
    def __len__(self) -> int:  # noqa:D105
        return len(self.images)
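

# Illustrative sketch (not part of the original module): how ``VisionDataset`` composes with a standard
# :class:`torch.utils.data.DataLoader`. The file names and root directory below are hypothetical placeholders.
def _example_vision_dataset_usage() -> torch.Tensor:
    """Load a batch of randomly cropped image tensors from two hypothetical image files."""
    dataset = VisionDataset(images=["images/Q42.jpg", "images/Q64.jpg"], root=pathlib.Path("/data"))
    loader = torch.utils.data.DataLoader(dataset, batch_size=2)
    # each item has shape (3, 224, 224) after the default RandomResizedCrop + ToTensor transforms
    return next(iter(loader))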


@parse_docdata
class VisualRepresentation(Representation):
    """Visual representations using a :mod:`torchvision` model.

    ---
    name: Visual
    """

    def __init__(
        self,
        images: ImageHints,
        encoder: Union[str, torch.nn.Module],
        layer_name: str,
        max_id: Optional[int] = None,
        shape: Optional[OneOrSequence[int]] = None,
        transforms: Optional[Sequence] = None,
        encoder_kwargs: OptionalKwargs = None,
        batch_size: int = 32,
        trainable: bool = True,
        **kwargs,
    ):
        """
        Initialize the representations.

        :param images:
            The images, either as tensors, or paths to image files.
        :param encoder:
            The encoder to use. If given as a string, lookup in :mod:`torchvision.models`.
        :param layer_name:
            The model's layer name to use for extracting the features, cf.
            :func:`torchvision.models.feature_extraction.create_feature_extractor`.
        :param max_id:
            The number of representations. If given, it must match the number of images.
        :param shape:
            The shape of an individual representation. If provided, it must match the encoder output dimension.
        :param transforms:
            Transformations to apply to the images. Notice that stochastic transformations will result in
            stochastic representations, too.
        :param encoder_kwargs:
            Additional keyword-based parameters passed to the encoder upon instantiation.
        :param batch_size:
            The batch size to use during encoding.
        :param trainable:
            Whether the encoder should be trainable.
        :param kwargs:
            Additional keyword-based parameters passed to :class:`~pykeen.nn.representation.Representation`.

        :raises ValueError:
            If ``max_id`` is provided and does not match the number of images.
        """
        _ensure_vision(self, models)
        self.images = VisionDataset(images=images, transforms=transforms)
        if isinstance(encoder, str):
            cls = getattr(models, encoder)
            encoder = cls(**(encoder_kwargs or {}))
        pool = functools.partial(torch.mean, dim=(-1, -2))
        encoder = models.feature_extraction.create_feature_extractor(
            model=encoder, return_nodes={layer_name: "feature"}
        )
        # infer shape
        with torch.inference_mode():
            encoder.eval()
            shape_ = self._encode(images=self.images[0].unsqueeze(dim=0), encoder=encoder, pool=pool).shape[1:]
        shape = ShapeError.verify(shape=shape_, reference=shape)
        if max_id is None:
            max_id = len(images)
        elif len(images) != max_id:
            raise ValueError(
                f"Inconsistent max_id={max_id} and len(images)={len(images)}. In case there are no images for all "
                f"IDs, you may consider using BackfillRepresentation.",
            )
        super().__init__(max_id=max_id, shape=shape, **kwargs)

        self.encoder = encoder
        self.pool = pool
        self.batch_size = batch_size or self.max_id
        self.encoder.train(trainable)
        self.encoder.requires_grad_(trainable)
        self.trainable = trainable

    @staticmethod
    def _encode(
        images: FloatTensor, encoder: torch.nn.Module, pool: Callable[[FloatTensor], FloatTensor]
    ) -> FloatTensor:
        """
        Encode images with the given encoder and pooling method.

        :param images: shape: ``(batch_size, num_channels, height, width)``
            A batch of images.
        :param encoder:
            The encoder, returning a dictionary with key ``"feature"``.
        :param pool:
            The pooling method to use.

        :return: shape: ``(batch_size, dim)``
            The encoded representations.
        """
        return pool(encoder(images)["feature"])

    # docstr-coverage: inherited
    def _plain_forward(self, indices: Optional[LongTensor] = None) -> FloatTensor:  # noqa: D102
        dataset = self.images
        if indices is not None:
            dataset = torch.utils.data.Subset(dataset=dataset, indices=indices)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
        # TODO: automatic batch size optimization?
        with torch.inference_mode(mode=not self.trainable):
            return torch.cat(
                [self._encode(images=images, encoder=self.encoder, pool=self.pool) for images in data_loader],
                dim=0,
            )
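

# Illustrative sketch (not part of the original module): building a ``VisualRepresentation`` on top of a
# torchvision ResNet-18 backbone. The encoder name ``"resnet18"`` and the node name ``"layer4"`` are
# assumptions about torchvision's model zoo and feature-extraction graph; adapt them to your encoder.
def _example_visual_representation(paths: ImageHints) -> FloatTensor:
    """Encode a list of image paths into pooled ResNet-18 ``layer4`` features (512-dimensional)."""
    representation = VisualRepresentation(
        images=paths,
        encoder="resnet18",  # randomly initialized; pass encoder_kwargs to select pretrained weights
        layer_name="layer4",
        trainable=False,  # keep the backbone frozen
    )
    # forward pass over all images; shape: (len(paths), 512)
    return representation()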


@parse_docdata
class WikidataVisualRepresentation(BackfillRepresentation):
    """
    Visual representations obtained from Wikidata and encoded with a vision encoder.

    If no image could be found for a certain Wikidata ID, a plain (trainable) embedding will be used instead.

    Example usage:

    .. literalinclude:: ../examples/nn/representation/visual_wikidata.py

    ---
    name: Wikidata Visual
    """

    def __init__(
        self,
        wikidata_ids: Sequence[str],
        max_id: Optional[int] = None,
        image_kwargs: OptionalKwargs = None,
        **kwargs,
    ):
        """
        Initialize the representation.

        :param wikidata_ids:
            The Wikidata IDs.
        :param max_id:
            The total number of IDs. If provided, must match the length of ``wikidata_ids``.
        :param image_kwargs:
            Keyword-based parameters passed to :meth:`pykeen.nn.vision.cache.WikidataImageCache.get_image_paths`.
        :param kwargs:
            Additional keyword-based parameters passed to
            :class:`pykeen.nn.vision.representation.VisualRepresentation`.

        :raises ValueError:
            If ``max_id`` does not match the number of Wikidata IDs.
        """
        max_id = max_id or len(wikidata_ids)
        if len(wikidata_ids) != max_id:
            raise ValueError(f"Inconsistent max_id={max_id} vs. len(wikidata_ids)={len(wikidata_ids)}")
        images = WikidataImageCache().get_image_paths(wikidata_ids, **(image_kwargs or {}))
        base_ids = [i for i, path in enumerate(images) if path is not None]
        images = [path for path in images if path is not None]
        super().__init__(
            max_id=max_id, base_ids=base_ids, base=VisualRepresentation, base_kwargs=dict(images=images, **kwargs)
        )

    @classmethod
    def from_triples_factory(
        cls,
        triples_factory: TriplesFactory,
        for_entities: bool = True,
        **kwargs,
    ) -> Self:
        """
        Prepare visual representations for Wikidata entities from a triples factory.

        :param triples_factory:
            The triples factory.
        :param for_entities:
            Whether to create the representations for entities (or relations).
        :param kwargs:
            Additional keyword-based arguments passed to
            :class:`pykeen.nn.vision.representation.WikidataVisualRepresentation`.

        :returns:
            A visual representation from the triples factory.
        """
        return cls(
            wikidata_ids=(
                triples_factory.entity_labeling if for_entities else triples_factory.relation_labeling
            ).all_labels(),
            **kwargs,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: Dataset,
        for_entities: bool = True,
        **kwargs,
    ) -> Self:
        """Prepare representations from a dataset.

        :param dataset:
            The dataset; needs to have Wikidata IDs as entity names.
        :param for_entities:
            Whether to create the representations for entities (or relations).
        :param kwargs:
            Additional keyword-based arguments passed to
            :class:`pykeen.nn.vision.representation.WikidataVisualRepresentation`.

        :return:
            A visual representation from the training factory in the dataset.

        :raises TypeError:
            If the triples factory does not provide labels.
        """
        if not isinstance(dataset.training, TriplesFactory):
            raise TypeError(f"{cls.__name__} requires access to labels, but dataset.training does not provide such.")
        return cls.from_triples_factory(triples_factory=dataset.training, for_entities=for_entities, **kwargs)
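

# Illustrative sketch (not part of the original module): visual representations for a handful of Wikidata
# entities. The QIDs below are arbitrary examples; entities for which no image can be retrieved fall back to
# the trainable backfill embedding. ``"resnet18"``/``"layer4"`` are the same torchvision assumptions as above,
# and image retrieval requires network access for the Wikidata image cache.
def _example_wikidata_visual_representation() -> FloatTensor:
    """Encode three Wikidata entities, backfilling missing images with a trainable embedding."""
    representation = WikidataVisualRepresentation(
        wikidata_ids=["Q42", "Q64", "Q11660"],
        encoder="resnet18",
        layer_name="layer4",
        trainable=False,
    )
    # shape: (3, 512); rows for IDs without a retrievable image come from the backfill embedding
    return representation()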