Spaces:
Sleeping
Sleeping
| import os, glob, random | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.utils.data as data | |
| import torchvision.transforms as transforms | |
| from joint_transforms import Compose, RandomHorizontallyFlip | |
| import cv2 | |
class SalObjDataset(data.Dataset):
    """Training dataset of (image, ground-truth mask, edge map) triplets.

    RGB images are the ``.jpg`` files in ``image_root``; ground-truth masks
    and edge maps are the ``.png`` files in ``gt_root`` and ``ek_root``.
    Each list is sorted independently and the three are paired by position,
    so the directories must hold corresponding files in the same sort order
    (presumably identical stems -- TODO confirm against the data layout).

    Args:
        image_root: Directory containing the ``.jpg`` input images.
        gt_root: Directory containing the ``.png`` ground-truth masks.
        ek_root: Directory containing the ``.png`` edge maps.
        trainsize: Side length every image, mask and edge map is resized to.

    Raises:
        ValueError: If the three directories yield different file counts.
    """

    def __init__(self, image_root, gt_root, ek_root, trainsize):
        self.trainsize = trainsize
        self.images, self.gts, self.eks = self._collect_paths(image_root, gt_root, ek_root)
        # Former attribute name for the edge-map list; kept so existing code
        # reading ``dataset.ek`` keeps working.
        self.ek = list(self.eks)
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            # Per-channel ImageNet mean/std, as expected by torchvision's
            # pretrained backbones.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.ek_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    @staticmethod
    def _list_files(root, suffixes):
        """Return sorted full paths of the files in ``root`` ending with ``suffixes``.

        ``suffixes`` is anything ``str.endswith`` accepts (a string or a tuple
        of strings). ``os.path.join`` lets ``root`` be given with or without a
        trailing separator.
        """
        return sorted(os.path.join(root, name) for name in os.listdir(root)
                      if name.endswith(suffixes))

    @classmethod
    def _collect_paths(cls, image_root, gt_root, ek_root):
        """List the image, mask and edge-map paths and check they line up.

        Returns:
            Tuple ``(images, gts, eks)`` of sorted path lists.

        Raises:
            ValueError: If the three lists differ in length, which would make
                the positional pairing in ``__getitem__`` misalign samples.
        """
        images = cls._list_files(image_root, '.jpg')
        gts = cls._list_files(gt_root, '.png')
        # Edge maps are listed from ek_root itself, so their file names do not
        # have to coincide with the mask file names in gt_root.
        eks = cls._list_files(ek_root, '.png')
        if not (len(images) == len(gts) == len(eks)):
            raise ValueError(
                'Mismatched dataset sizes: %d images, %d masks, %d edge maps'
                % (len(images), len(gts), len(eks)))
        return images, gts, eks

    def __getitem__(self, index):
        """Return the transformed ``(image, gt, ek)`` tensors for ``index``."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        ek = self.binary_loader(self.eks[index])
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        ek = self.ek_transform(ek)
        return image, gt, ek

    def rgb_loader(self, path):
        """Open ``path`` with PIL and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # Convert inside the ``with`` block: PIL opens images lazily, so
            # the pixel data must be read before the file is closed.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` with PIL and return it converted to 8-bit grayscale ('L')."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def __len__(self):
        """Number of samples, i.e. the number of ``.jpg`` images found."""
        return self.size
def get_loader(image_root, gt_root, ek_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True):
    """Wrap a ``SalObjDataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        image_root: Directory of ``.jpg`` training images.
        gt_root: Directory of ``.png`` ground-truth masks.
        ek_root: Directory of ``.png`` edge maps.
        batchsize: Number of samples per batch.
        trainsize: Side length every sample is resized to.
        shuffle: Whether to reshuffle the samples every epoch.
        num_workers: Number of loader worker processes (0 = load in the main process).
        pin_memory: Forwarded unchanged to ``DataLoader``.

    Returns:
        A ``DataLoader`` producing batches built from ``(image, gt, ek)`` samples.
    """
    loader_options = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_set = SalObjDataset(image_root, gt_root, ek_root, trainsize)
    return data.DataLoader(dataset=train_set, **loader_options)
class test_dataset:
    """Sequential evaluation loader: each ``load_data()`` call returns one sample.

    Images are the ``.jpg`` files in ``image_root``; ground-truth maps are the
    ``.jpg`` or ``.png`` files in ``gt_root``. Both lists are sorted and paired
    by position.

    Args:
        image_root: Directory containing the ``.jpg`` test images.
        gt_root: Directory containing the ground-truth maps.
        testsize: Side length the input image is resized to.
    """

    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        # os.path.join lets the roots be given with or without a trailing separator.
        self.images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root)
                             if f.endswith('.jpg'))
        self.gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root)
                          if f.endswith(('.jpg', '.png')))
        self.img_transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            # Per-channel ImageNet mean/std, as expected by torchvision's
            # pretrained backbones.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # Not applied by load_data(), which returns the ground truth as an
        # untransformed PIL image.
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        """Load the sample at the current cursor position and advance the cursor.

        Returns:
            Tuple ``(image, gt, name)``: the normalized image tensor with a
            leading batch dimension of size 1, the ground truth as an
            untransformed grayscale PIL image, and the image file name with a
            trailing ``.jpg`` replaced by ``.png``.

        Raises:
            IndexError: Once every listed image has been consumed.
        """
        image = self.rgb_loader(self.images[self.index])
        image = self.img_transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = self._output_name(self.images[self.index])
        self.index += 1
        return image, gt, name

    @staticmethod
    def _output_name(path):
        """Return the file name of ``path`` with a trailing ``.jpg`` swapped for ``.png``."""
        name = os.path.basename(path)
        if name.endswith('.jpg'):
            # Strip only the final extension; str.split('.jpg') would cut the
            # name at the first '.jpg' occurring anywhere in it.
            name = name[:-len('.jpg')] + '.png'
        return name

    def rgb_loader(self, path):
        """Open ``path`` with PIL and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # PIL opens lazily; convert while the file is still open.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` with PIL and return it converted to 8-bit grayscale ('L')."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def __len__(self):
        """Number of test images found in ``image_root``."""
        return self.size
def transform_image(image, testsize):
    """Preprocess a single image for inference.

    The image is converted to RGB when needed, resized to
    ``(testsize, testsize)``, turned into a tensor and normalized with the
    ImageNet mean/std also used by the dataset classes in this module.

    Args:
        image: PIL Image object. Non-RGB modes (e.g. ``'RGBA'``, ``'L'``,
            ``'P'``) are converted to RGB first.
        testsize: Target side length.

    Returns:
        torch.Tensor: Preprocessed image tensor; for a PIL input its shape is
        ``(3, testsize, testsize)`` with no batch dimension.
    """
    # Normalize below uses three per-channel values, so a tensor with any
    # other channel count (RGBA PNGs, grayscale or palette images) would make
    # it raise. Duck-typed so non-PIL inputs pass through untouched.
    if hasattr(image, 'convert') and getattr(image, 'mode', 'RGB') != 'RGB':
        image = image.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((testsize, testsize)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return transform(image)