"""Image preprocessing and transformation utilities.
Provides functions and classes for resizing, normalizing, and transforming
images into tensors suitable for neural network inference.
"""
from collections.abc import Sequence
from pathlib import Path

from PIL import Image
import torch
import torchvision
# Per-channel mean and standard deviation used for ImageNet-style
# normalization; both are passed to ConditionalNormalize when the module-level
# _TRANSFORM pipeline is built.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)
[docs]
class ConditionalNormalize:
    """Normalize an image tensor, expanding greyscale input to three channels.

    A tensor whose channel dimension has size 1 (greyscale) is converted to
    3 channels by replicating that channel. The result is then normalized
    with the per-channel mean and standard deviation given at construction.

    Both unbatched ``(C, H, W)`` and batched ``(..., C, H, W)`` tensors are
    handled: the channel axis is always taken to be the third from the end.
    """

    def __init__(self, mean: Sequence[float], std: Sequence[float]) -> None:
        """Initialize the normalizer.

        Args:
            mean: Mean values for each channel for normalization.
            std: Standard deviation values for each channel for normalization.
        """
        self.normalize = torchvision.transforms.Normalize(mean=mean, std=std)

    def __repr__(self) -> str:
        """Return a debug representation showing the wrapped normalizer."""
        return f"{type(self).__name__}(normalize={self.normalize!r})"

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply channel replication and normalization to ``tensor``.

        Args:
            tensor: Image tensor laid out as ``(C, H, W)`` or
                ``(..., C, H, W)``.

        Returns:
            The normalized tensor. Single-channel input comes back with
            3 channels; any other channel count is passed to the normalizer
            unchanged.
        """
        # Index the channel axis from the end so that a leading batch
        # dimension of size 1 is never mistaken for a greyscale channel.
        # Tensors with fewer than 3 dimensions are left for the normalizer
        # to reject.
        if tensor.ndim >= 3 and tensor.shape[-3] == 1:
            tensor = torch.cat((tensor, tensor, tensor), dim=-3)
        return self.normalize(tensor)
# Preprocessing pipeline, applied in the order listed: resize, centre-crop to
# 224x224, convert to a tensor, then normalize with _MEAN/_STD (replicating a
# single greyscale channel to three channels first).
_TRANSFORM = torchvision.transforms.Compose(
[
# With a single int, torchvision's Resize scales the shorter edge to 224.
torchvision.transforms.Resize(224),
# Crop the central 224x224 region so every output has the same spatial size.
torchvision.transforms.CenterCrop((224, 224)),
# Convert the image to a channel-first tensor, as ConditionalNormalize expects.
torchvision.transforms.ToTensor(),
ConditionalNormalize(_MEAN, _STD),
]
)
[docs]
def save_images_from_iteration(
    directory_path: str,
    images: Sequence[Image.Image],
    run_id: str,
    iteration: int,
) -> None:
    """Save the images from one iteration to disk in JPEG format.

    Each image is written to
    ``<directory_path>/<run_id>_<iteration>_<index>``, where ``index`` is the
    image's zero-based position in ``images``. File names carry no extension;
    the JPEG encoder is selected explicitly via ``format="jpeg"``.

    The target directory (including missing parents) is created if it does
    not exist. An empty ``directory_path`` refers to the current working
    directory.

    Args:
        directory_path: Directory path to save images to.
        images: Sequence of PIL Image objects to save.
        run_id: Unique run identifier to include in filenames.
        iteration: Iteration number to include in filenames.

    Raises:
        Exception: Nothing is caught here; any error raised while creating
            the directory or by ``Image.save`` propagates to the caller.
    """
    # Build paths with pathlib rather than "/"-joined strings so that an
    # empty directory_path does not resolve to the filesystem root.
    target_dir = Path(directory_path)
    target_dir.mkdir(parents=True, exist_ok=True)
    for image_number, image in enumerate(images):
        filename = f"{run_id}_{iteration}_{image_number}"
        image.save(target_dir / filename, format="jpeg")