"""Dataset helpers for salient-object training and evaluation.

Provides a training dataset that yields (image, ground truth, ek map)
triples, a sequential test-time loader, and a single-image preprocessing
function for inference.
"""

import os
import glob
import random
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from joint_transforms import Compose, RandomHorizontallyFlip
import cv2

# Per-channel normalisation statistics applied to every RGB input.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _list_files(root, extensions):
    """Return the sorted paths of files in *root* ending with *extensions*.

    Args:
        root: Directory to scan (with or without a trailing separator).
        extensions: A suffix string or a tuple of suffix strings.

    Returns:
        A sorted list of paths, each joined onto *root*.
    """
    # os.path.join gives the same result as ``root + name`` when root ends
    # with a separator, and stays correct when it does not.
    return sorted(
        os.path.join(root, name)
        for name in os.listdir(root)
        if name.endswith(extensions)
    )


def _build_image_transform(size):
    """Build the resize -> tensor -> normalise pipeline for RGB inputs."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])


def _build_mask_transform(size):
    """Build the resize -> tensor pipeline for single-channel maps."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
    ])


def _load_image(path, mode):
    """Open the image at *path* and return it converted to PIL *mode*."""
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() must run while the file is still open: PIL decodes lazily.
        return img.convert(mode)


class SalObjDataset(data.Dataset):
    """Training dataset yielding (image, gt, ek) tensor triples.

    Images are ``.jpg`` files from ``image_root``; ground-truth maps and
    ek maps are ``.png`` files from ``gt_root`` and ``ek_root``. The three
    file lists are sorted independently and paired by position, so the
    directories are expected to contain matching file names.
    """

    def __init__(self, image_root, gt_root, ek_root, trainsize):
        """Collect file paths and build the per-modality transforms.

        Args:
            image_root: Directory containing the ``.jpg`` input images.
            gt_root: Directory containing the ``.png`` ground-truth maps.
            ek_root: Directory containing the ``.png`` ek maps.
            trainsize: Side length every sample is resized to (square).
        """
        self.trainsize = trainsize
        self.images = _list_files(image_root, '.jpg')
        self.gts = _list_files(gt_root, '.png')
        # Fixed: the ek list used to be built from os.listdir(gt_root),
        # so files actually present in ek_root were never consulted.
        self.eks = _list_files(ek_root, '.png')
        # Kept as an alias for callers that read the old attribute name.
        self.ek = self.eks
        self.size = len(self.images)
        self.img_transform = _build_image_transform(self.trainsize)
        self.gt_transform = _build_mask_transform(self.trainsize)
        self.ek_transform = _build_mask_transform(self.trainsize)

    def __getitem__(self, index):
        """Return the transformed (image, gt, ek) tensors at *index*."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        ek = self.binary_loader(self.eks[index])
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        ek = self.ek_transform(ek)
        return image, gt, ek

    def rgb_loader(self, path):
        """Load the file at *path* as an RGB PIL image."""
        return _load_image(path, 'RGB')

    def binary_loader(self, path):
        """Load the file at *path* as a single-channel ('L') PIL image."""
        return _load_image(path, 'L')

    def __len__(self):
        """Return the number of input images."""
        return self.size


def get_loader(image_root, gt_root, ek_root, batchsize, trainsize,
               shuffle=True, num_workers=0, pin_memory=True):
    """Create a DataLoader over a :class:`SalObjDataset`.

    Args:
        image_root: Directory containing the ``.jpg`` input images.
        gt_root: Directory containing the ``.png`` ground-truth maps.
        ek_root: Directory containing the ``.png`` ek maps.
        batchsize: Number of samples per batch.
        trainsize: Side length every sample is resized to.
        shuffle: Passed through to the DataLoader.
        num_workers: Passed through to the DataLoader.
        pin_memory: Passed through to the DataLoader.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding (image, gt, ek) batches.
    """
    dataset = SalObjDataset(image_root, gt_root, ek_root, trainsize)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batchsize,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory)
    return data_loader


class test_dataset:
    """Sequential test-time loader returning one sample per call.

    Images are ``.jpg`` files from ``image_root``; ground truths are
    ``.jpg`` or ``.png`` files from ``gt_root``. Both lists are sorted and
    paired by position.
    """

    def __init__(self, image_root, gt_root, testsize):
        """Collect file paths and build the image transform.

        Args:
            image_root: Directory containing the ``.jpg`` input images.
            gt_root: Directory containing the ground-truth maps.
            testsize: Side length input images are resized to (square).
        """
        self.testsize = testsize
        self.images = _list_files(image_root, '.jpg')
        self.gts = _list_files(gt_root, ('.jpg', '.png'))
        self.img_transform = _build_image_transform(self.testsize)
        # Not used by load_data (which returns the PIL ground truth);
        # retained as a public attribute.
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        """Return the next (image, gt, name) sample and advance the cursor.

        Returns:
            A tuple of the transformed image tensor with a leading batch
            dimension added, the ground truth as a mode-'L' PIL image
            (not resized or converted to a tensor), and the image file
            name with its ``.jpg`` suffix replaced by ``.png``.

        Raises:
            IndexError: If called more than ``self.size`` times.
        """
        image = self.rgb_loader(self.images[self.index])
        image = self.img_transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        # basename handles the platform's separator, unlike split('/').
        name = os.path.basename(self.images[self.index])
        if name.endswith('.jpg'):
            # Replace only the trailing suffix; splitting on '.jpg' would
            # truncate names that contain '.jpg' earlier in the string.
            name = name[:-len('.jpg')] + '.png'
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        """Load the file at *path* as an RGB PIL image."""
        return _load_image(path, 'RGB')

    def binary_loader(self, path):
        """Load the file at *path* as a single-channel ('L') PIL image."""
        return _load_image(path, 'L')


def transform_image(image, testsize):
    """Preprocess a single image for inference.

    Args:
        image: PIL Image object.
        testsize: Target side length (the image is resized to a square).

    Returns:
        torch.Tensor: The preprocessed image tensor (no batch dimension
        is added).
    """
    transform = _build_image_transform(testsize)
    return transform(image)