kunkk's picture
Upload data.py
385ef18 verified
raw
history blame
4.33 kB
import os, glob, random
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from joint_transforms import Compose, RandomHorizontallyFlip
import cv2
class SalObjDataset(data.Dataset):
    """Training dataset yielding ``(image, gt, ek)`` tensor triples.

    Samples are paired by position: the i-th entry of each sorted file list
    (images, ground-truth maps, ek maps) is assumed to belong to the same
    sample, so the three folders should contain matching, identically
    ordered files.

    Args:
        image_root: Directory containing the ``.jpg`` input images.
        gt_root: Directory containing the ``.png`` ground-truth maps.
        ek_root: Directory containing the ``.png`` ek maps.
        trainsize: Side length; every item is resized to
            ``(trainsize, trainsize)``.
    """

    def __init__(self, image_root, gt_root, ek_root, trainsize):
        self.trainsize = trainsize
        self.images = self._list_files(image_root, ('.jpg',))
        self.gts = self._list_files(gt_root, ('.png',))
        # The ek list must be enumerated from ek_root itself. Listing gt_root
        # here (as the code previously did) only works when both folders
        # happen to hold identically named files.
        self.ek = self._list_files(ek_root, ('.png',))
        self.eks = sorted(self.ek)
        # The dataset length is driven by the number of images.
        self.size = len(self.images)

        target_size = (self.trainsize, self.trainsize)
        self.img_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor()])
        self.ek_transform = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor()])

    @staticmethod
    def _list_files(root, suffixes):
        """Return the sorted paths of files in *root* ending with *suffixes*."""
        # os.path.join yields the same path as ``root + name`` when root ends
        # with a separator, and a valid path when it does not.
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(suffixes)
        )

    def __getitem__(self, index):
        """Load, resize and tensorize the sample at *index*.

        Returns:
            Tuple ``(image, gt, ek)`` of transformed tensors.
        """
        image = self.img_transform(self.rgb_loader(self.images[index]))
        gt = self.gt_transform(self.binary_loader(self.gts[index]))
        ek = self.ek_transform(self.binary_loader(self.eks[index]))
        return image, gt, ek

    def rgb_loader(self, path):
        """Open the file at *path* and return it as a mode ``'RGB'`` PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() is called while the file is still open.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open the file at *path* and return it as a mode ``'L'`` PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def __len__(self):
        """Return the number of input images found at construction time."""
        return self.size
def get_loader(image_root, gt_root, ek_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True):
    """Create a ``DataLoader`` over a freshly built ``SalObjDataset``.

    Args:
        image_root: Passed to ``SalObjDataset`` as the image directory.
        gt_root: Passed to ``SalObjDataset`` as the ground-truth directory.
        ek_root: Passed to ``SalObjDataset`` as the ek-map directory.
        batchsize: Forwarded to ``DataLoader`` as ``batch_size``.
        trainsize: Passed to ``SalObjDataset`` as the resize side length.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    loader_options = {
        'batch_size': batchsize,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': pin_memory,
    }
    train_set = SalObjDataset(image_root, gt_root, ek_root, trainsize)
    return data.DataLoader(dataset=train_set, **loader_options)
class test_dataset:
    """Sequential reader over a test set of images and ground-truth maps.

    Each call to ``load_data`` returns the next sample and advances an
    internal cursor (``self.index``). Images and ground truths are paired by
    position in their sorted file lists.

    Args:
        image_root: Directory containing the ``.jpg`` input images.
        gt_root: Directory containing ``.jpg`` or ``.png`` ground-truth maps.
        testsize: Side length; images are resized to ``(testsize, testsize)``.
    """

    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        # os.path.join matches ``root + name`` for roots that end with a
        # separator and also produces valid paths for roots that do not.
        self.images = sorted(
            os.path.join(image_root, f)
            for f in os.listdir(image_root)
            if f.endswith('.jpg')
        )
        self.gts = sorted(
            os.path.join(gt_root, f)
            for f in os.listdir(gt_root)
            if f.endswith(('.jpg', '.png'))
        )
        self.size = len(self.images)
        self.index = 0
        self.img_transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # Kept for callers; load_data itself returns the ground truth
        # untransformed.
        self.gt_transform = transforms.ToTensor()

    @staticmethod
    def _prediction_name(path):
        """Return the base name of *path* with a trailing ``.jpg`` swapped for ``.png``."""
        name = os.path.basename(path)
        if name.endswith('.jpg'):
            # Strip only the final extension; splitting on '.jpg' would cut
            # the name at the first occurrence anywhere in it.
            name = name[:-len('.jpg')] + '.png'
        return name

    def load_data(self):
        """Return the next sample and advance the cursor.

        Returns:
            Tuple ``(image, gt, name)``: the transformed image tensor with a
            leading dimension added via ``unsqueeze(0)``, the ground truth as
            a mode ``'L'`` PIL image, and the output file name derived from
            the image file name.

        Raises:
            IndexError: When called after every image has been consumed.
        """
        image_path = self.images[self.index]
        image = self.rgb_loader(image_path)
        image = self.img_transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = self._prediction_name(image_path)
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        """Open the file at *path* and return it as a mode ``'RGB'`` PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open the file at *path* and return it as a mode ``'L'`` PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')
def transform_image(image, testsize):
    """Preprocess a single image for inference.

    Args:
        image: PIL Image object.
        testsize: Target side length; the image is resized to
            ``(testsize, testsize)``.

    Returns:
        torch.Tensor: The preprocessed image tensor.
    """
    # Per-channel normalization constants, identical to those used by the
    # dataset classes in this file.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((testsize, testsize)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(image)