Source code for pykeen.nn.vision.representation

"""
A module of vision related components.

Generally requires :mod:`torchvision` to be installed.
"""

import functools
import pathlib
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
import torch.nn
import torch.utils.data
from class_resolver import OptionalKwargs
from docdata import parse_docdata

from .cache import WikidataImageCache
from ..representation import BackfillRepresentation, Representation
from ..utils import ShapeError
from ...datasets import Dataset
from ...triples import TriplesFactory
from ...typing import FloatTensor, LongTensor, OneOrSequence

try:
    from PIL import Image
    from torchvision import models
    from torchvision import transforms as vision_transforms
except ImportError:
    models = vision_transforms = Image = None

__all__ = [
    "VisionDataset",
    "VisualRepresentation",
    "WikidataVisualRepresentation",
    "ImageHint",
    "ImageHints",
]


def _ensure_vision(instance: object, module: Optional[Any]):
    """Raise an error if the required vision dependency is not installed."""
    if module is None:
        raise ImportError(f"{instance.__class__.__name__} requires `torchvision` to be installed.")


#: A path to an image file or a tensor representation of the image
ImageHint = Union[str, pathlib.Path, torch.Tensor]
#: A sequence of image hints
ImageHints = Sequence[ImageHint]


class VisionDataset(torch.utils.data.Dataset):
    """
    A dataset of images with data augmentation.

    .. note ::
        requires :mod:`torchvision` to be installed.
    """

    def __init__(
        self,
        images: ImageHints,
        transforms: Optional[Sequence] = None,
        root: Optional[pathlib.Path] = None,
    ) -> None:
        """
        Initialize the dataset.

        :param images:
            the images, either as (relative) path, or preprocessed tensors.
        :param transforms:
            a sequence of transformations to apply to the images, cf. :mod:`torchvision.transforms`
        :param root:
            the root directory for images
        """
        _ensure_vision(self, vision_transforms)
        super().__init__()
        if root is None:
            root = pathlib.Path.cwd()
        self.root = pathlib.Path(root)
        self.images = images
        if transforms is None:
            transforms = [vision_transforms.RandomResizedCrop(size=224), vision_transforms.ToTensor()]
        transforms = list(transforms)
        transforms.append(vision_transforms.ConvertImageDtype(torch.get_default_dtype()))
        self.transforms = vision_transforms.Compose(transforms=transforms)

    # docstr-coverage: inherited
    def __getitem__(self, item: int) -> torch.Tensor:  # noqa: D105
        _ensure_vision(self, Image)
        image = self.images[item]
        if isinstance(image, (str, pathlib.Path)):
            path = pathlib.Path(image)
            if not path.is_absolute():
                path = self.root.joinpath(path)
            image = Image.open(path)
        assert isinstance(image, (torch.Tensor, Image.Image))
        return self.transforms(image)

    # docstr-coverage: inherited
    def __len__(self) -> int:  # noqa: D105
        return len(self.images)
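
# A minimal usage sketch, not part of the original module: build a ``VisionDataset`` from
# image files on disk and read back transformed tensors. The file names and the root
# directory below are hypothetical placeholders.
def _vision_dataset_example() -> None:
    dataset = VisionDataset(
        images=["cat.jpg", "dog.png"],  # hypothetical relative paths, resolved against ``root``
        root=pathlib.Path("/data/images"),  # hypothetical image directory
    )
    print(len(dataset))  # number of images
    print(dataset[0].shape)  # e.g. torch.Size([3, 224, 224]) after RandomResizedCrop(224) + ToTensor()
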
@parse_docdata
class VisualRepresentation(Representation):
    """Visual representations using a torchvision model.

    ---
    name: Visual
    """

    def __init__(
        self,
        images: ImageHints,
        encoder: Union[str, torch.nn.Module],
        layer_name: str,
        max_id: Optional[int] = None,
        shape: Optional[OneOrSequence[int]] = None,
        transforms: Optional[Sequence] = None,
        encoder_kwargs: OptionalKwargs = None,
        batch_size: int = 32,
        trainable: bool = True,
        **kwargs,
    ):
        """
        Initialize the representations.

        :param images:
            the images, either as tensors, or paths to image files.
        :param encoder:
            the encoder to use. If given as a string, lookup in :mod:`torchvision.models`
        :param layer_name:
            the model's layer name to use for extracting the features,
            cf. :func:`torchvision.models.feature_extraction.create_feature_extractor`
        :param max_id:
            the number of representations. If given, it must match the number of images.
        :param shape:
            the shape of an individual representation. If provided, it must match the encoder output dimension
        :param transforms:
            transformations to apply to the images. Notice that stochastic transformations will result in
            stochastic representations, too.
        :param encoder_kwargs:
            additional keyword-based parameters passed to the encoder upon instantiation.
        :param batch_size:
            the batch size to use during encoding
        :param trainable:
            whether the encoder should be trainable
        :param kwargs:
            additional keyword-based parameters passed to :meth:`Representation.__init__`.

        :raises ValueError:
            if `max_id` is provided and does not match the number of images
        """
        _ensure_vision(self, models)

        self.images = VisionDataset(images=images, transforms=transforms)

        if isinstance(encoder, str):
            cls = getattr(models, encoder)
            # instantiate the torchvision model with the given keyword-based parameters
            encoder = cls(**(encoder_kwargs or {}))

        pool = functools.partial(torch.mean, dim=(-1, -2))
        encoder = models.feature_extraction.create_feature_extractor(
            model=encoder, return_nodes={layer_name: "feature"}
        )

        # infer the representation shape from a single forward pass
        with torch.inference_mode():
            encoder.eval()
            shape_ = self._encode(images=self.images[0].unsqueeze(dim=0), encoder=encoder, pool=pool).shape[1:]
        shape = ShapeError.verify(shape=shape_, reference=shape)

        if max_id is None:
            max_id = len(images)
        elif len(images) != max_id:
            raise ValueError(
                f"Inconsistent max_id={max_id} and len(images)={len(images)}. If images are not available for "
                f"all IDs, you may consider using a BackfillRepresentation.",
            )

        super().__init__(max_id=max_id, shape=shape, **kwargs)

        self.encoder = encoder
        self.pool = pool
        self.batch_size = batch_size or self.max_id
        self.encoder.train(trainable)
        self.encoder.requires_grad_(trainable)
        self.trainable = trainable

    @staticmethod
    def _encode(
        images: FloatTensor, encoder: torch.nn.Module, pool: Callable[[FloatTensor], FloatTensor]
    ) -> FloatTensor:
        """
        Encode images with the given encoder and pooling methods.

        :param images: shape: (batch_size, num_channels, height, width)
            a batch of images
        :param encoder:
            the encoder, returning a dictionary with key "feature"
        :param pool:
            the pooling method to use

        :return: shape: (batch_size, dim)
            the encoded representations.
        """
        return pool(encoder(images)["feature"])

    # docstr-coverage: inherited
    def _plain_forward(self, indices: Optional[LongTensor] = None) -> FloatTensor:  # noqa: D102
        dataset = self.images
        if indices is not None:
            dataset = torch.utils.data.Subset(dataset=dataset, indices=indices)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size)
        with torch.inference_mode(mode=not self.trainable):
            # concatenate along the batch dimension to obtain one representation per index
            return torch.cat(
                [self._encode(images=images, encoder=self.encoder, pool=self.pool) for images in data_loader],
                dim=0,
            )
@parse_docdata
class WikidataVisualRepresentation(BackfillRepresentation):
    """
    Visual representations obtained from Wikidata and encoded with a vision encoder.

    If no image could be found for a certain Wikidata ID, a plain (trainable) embedding will be used instead.

    Example usage:

    .. code-block:: python

        from pykeen.datasets import get_dataset
        from pykeen.models import ERModel
        from pykeen.nn import WikidataVisualRepresentation
        from pykeen.pipeline import pipeline

        dataset = get_dataset(dataset="codexsmall")
        entity_representations = WikidataVisualRepresentation.from_dataset(dataset=dataset)
        result = pipeline(
            dataset=dataset,
            model=ERModel,
            model_kwargs=dict(
                interaction="distmult",
                entity_representations=entity_representations,
                relation_representation_kwargs=dict(
                    shape=entity_representations.shape,
                ),
            ),
        )

    ---
    name: Wikidata Visual
    """

    def __init__(
        self,
        wikidata_ids: Sequence[str],
        max_id: Optional[int] = None,
        image_kwargs: OptionalKwargs = None,
        **kwargs,
    ):
        """
        Initialize the representation.

        :param wikidata_ids:
            the Wikidata IDs
        :param max_id:
            the total number of IDs. If provided, must match the length of `wikidata_ids`.
        :param image_kwargs:
            keyword-based parameters passed to :meth:`WikidataImageCache.get_image_paths`
        :param kwargs:
            additional keyword-based parameters passed to :meth:`VisualRepresentation.__init__`

        :raises ValueError:
            if `max_id` is provided and does not match the number of Wikidata IDs
        """
        max_id = max_id or len(wikidata_ids)
        if len(wikidata_ids) != max_id:
            raise ValueError(f"Inconsistent max_id={max_id} vs. len(wikidata_ids)={len(wikidata_ids)}")
        images = WikidataImageCache().get_image_paths(wikidata_ids, **(image_kwargs or {}))
        base_ids = [i for i, path in enumerate(images) if path is not None]
        images = [path for path in images if path is not None]
        super().__init__(
            max_id=max_id, base_ids=base_ids, base=VisualRepresentation, base_kwargs=dict(images=images, **kwargs)
        )

    @classmethod
    def from_triples_factory(
        cls,
        triples_factory: TriplesFactory,
        for_entities: bool = True,
        **kwargs,
    ) -> "WikidataVisualRepresentation":
        """
        Prepare visual representations for Wikidata entities from a triples factory.

        :param triples_factory:
            the triples factory
        :param for_entities:
            whether to create representations for entities (or relations)
        :param kwargs:
            additional keyword-based arguments passed to :meth:`WikidataVisualRepresentation.__init__`

        :returns:
            a visual representation from the triples factory
        """
        return cls(
            wikidata_ids=(
                triples_factory.entity_labeling if for_entities else triples_factory.relation_labeling
            ).all_labels(),
            **kwargs,
        )

    @classmethod
    def from_dataset(
        cls,
        dataset: Dataset,
        **kwargs,
    ) -> "WikidataVisualRepresentation":
        """Prepare representations from a dataset.

        :param dataset:
            the dataset; needs to have Wikidata IDs as entity names
        :param kwargs:
            additional keyword-based parameters passed to :meth:`WikidataVisualRepresentation.from_triples_factory`

        :return:
            the representation

        :raises TypeError:
            if the triples factory does not provide labels
        """
        if not isinstance(dataset.training, TriplesFactory):
            raise TypeError(f"{cls.__name__} requires access to labels, but dataset.training does not provide such.")
        return cls.from_triples_factory(triples_factory=dataset.training, **kwargs)
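
# A minimal usage sketch, not part of the original module: construct representations directly
# from Wikidata IDs (Q42 and Q937 are real Wikidata IDs, used here only for illustration).
# Images are fetched and cached via WikidataImageCache; entities without an image fall back to
# a trainable embedding through the BackfillRepresentation base class.
def _wikidata_visual_representation_example() -> None:
    representation = WikidataVisualRepresentation(
        wikidata_ids=["Q42", "Q937"],
        encoder="resnet18",  # forwarded to VisualRepresentation.__init__
        layer_name="layer4",
        trainable=False,
    )
    print(representation.max_id)  # 2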