from ts.torch_handler.base_handler import BaseHandler import torch import torchvision.transforms as transforms from PIL import Image import io import time import logging import torch.nn.functional as F import numpy as np from skimage.metrics import peak_signal_noise_ratio, structural_similarity logger = logging.getLogger(__name__) class ImageHandler(BaseHandler): def __init__(self): super(ImageHandler, self).__init__() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.transform = transforms.Compose([transforms.ToTensor()]) self.input_tensor_for_metrics = None self.start_time = 0 def preprocess(self, data): self.start_time = time.time() image_bytes = data[0].get("body") image = Image.open(io.BytesIO(image_bytes)).convert("RGB") width, height = image.size logger.info(f"DATA_QUALITY: resolution={width}x{height}, format={image.format}") tensor = self.transform(image).unsqueeze(0).to(self.device) self.input_tensor_for_metrics = tensor.clone().detach() return tensor def inference(self, data, *args, **kwargs): with torch.no_grad(): output = self.model(data) return output def postprocess(self, data): output_batched = data input_batched = self.input_tensor_for_metrics output_tensor = output_batched.squeeze(0).cpu().clamp(0, 1) input_tensor = input_batched.squeeze(0).cpu() output_tensor_resized = output_tensor if output_tensor.shape != input_tensor.shape: output_tensor_resized = F.interpolate( output_tensor.unsqueeze(0), size=input_tensor.shape[-2:], mode='bilinear', align_corners=False ).squeeze(0) pixel_difference = torch.mean(torch.abs(input_tensor - output_tensor_resized)).item() logger.info(f"OUTPUT_QUALITY: denoising_intensity={pixel_difference:.4f}") end_time = time.time() latency_ms = (end_time - self.start_time) * 1000 logger.info(f"OPERATIONAL_HEALTH: total_latency={latency_ms:.2f}ms") output_image = transforms.ToPILImage()(output_tensor) buf = io.BytesIO() output_image.save(buf, format="PNG") return [buf.getvalue()]