File size: 1,120 Bytes
82e3da2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from ts.torch_handler.base_handler import BaseHandler
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

class ImageHandler(BaseHandler):
    """TorchServe handler for image-to-image models.

    Each request carries one encoded image (any format PIL can open). The
    handler decodes every request in the batch to an RGB float tensor, runs
    ``self.model`` on the stacked batch, and returns one PNG-encoded image
    (``bytes``) per request, in request order.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): BaseHandler.initialize() presumably reassigns
        # self.device from the worker's GPU assignment — confirm; this value
        # acts as a fallback before initialization.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # ToTensor turns a PIL image into a float (C, H, W) tensor in [0, 1].
        self.transform = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _request_payload(row):
        """Return the raw image bytes carried by one TorchServe request row.

        Raw binary POSTs arrive under ``"body"``; multipart form uploads
        conventionally use ``"data"``. A ``str`` payload is treated as base64.

        Raises:
            ValueError: if the row carries neither key.
        """
        payload = row.get("body")
        if payload is None:
            payload = row.get("data")
        if payload is None:
            raise ValueError("request has no 'body' or 'data' payload")
        if isinstance(payload, str):
            import base64  # only needed for clients that send base64 text

            payload = base64.b64decode(payload)
        return payload

    def preprocess(self, data):
        """Decode every request in the batch into one stacked tensor.

        Args:
            data: list of TorchServe request rows (dict-like), one per request.

        Returns:
            Float tensor of shape (N, 3, H, W) on ``self.device``, where
            N == len(data). For a single request this is (1, 3, H, W).

        Raises:
            ValueError: if the batch is empty, a row has no payload, or the
                images differ in size (they cannot be stacked into one batch).
        """
        if not data:
            raise ValueError("received an empty batch")

        tensors = []
        for row in data:
            image_bytes = self._request_payload(row)
            # ``with`` closes the decoder; convert() returns a loaded copy.
            with Image.open(io.BytesIO(image_bytes)) as image:
                tensors.append(self.transform(image.convert("RGB")))

        shapes = {tuple(t.shape) for t in tensors}
        if len(shapes) > 1:
            raise ValueError(
                f"cannot batch images of different sizes: {sorted(shapes)}"
            )
        return torch.stack(tensors).to(self.device)

    def inference(self, data, *args, **kwargs):
        """Run the model on the preprocessed batch without tracking gradients.

        ``self.model`` is not set in this class; presumably BaseHandler
        loads it during initialize() — confirm.
        """
        with torch.no_grad():
            output = self.model(data)
        return output

    def postprocess(self, data):
        """Encode the model output as PNG bytes, one entry per request.

        Args:
            data: model output tensor. Assumed (N, C, H, W) with values meant
                to lie in [0, 1] — TODO confirm for the served model. Outputs
                that are not 4-D are treated as a single image after dropping
                a leading size-1 dimension.

        Returns:
            list of PNG-encoded ``bytes``, one per image.
        """
        # Clamp so ToPILImage's x255 scaling stays within the 0..255 byte range.
        images = data.cpu().clamp(0, 1)
        if images.dim() == 4:
            per_request = list(images.unbind(0))
        else:
            per_request = [images.squeeze(0)]
        return [self._encode_png(image) for image in per_request]

    @staticmethod
    def _encode_png(tensor):
        """Encode one (C, H, W) or (H, W) float tensor in [0, 1] as PNG bytes."""
        buf = io.BytesIO()
        transforms.ToPILImage()(tensor).save(buf, format="PNG")
        return buf.getvalue()