kunkk commited on
Commit
385ef18
·
verified ·
1 Parent(s): ff92051

Upload data.py

Browse files
Files changed (1) hide show
  1. utils1/data.py +125 -0
utils1/data.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob, random
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+ import torch
6
+ import torch.utils.data as data
7
+ import torchvision.transforms as transforms
8
+ from joint_transforms import Compose, RandomHorizontallyFlip
9
+
10
+ import cv2
11
+
12
+
13
+ class SalObjDataset(data.Dataset):
14
+ def __init__(self, image_root, gt_root, ek_root, trainsize):
15
+ self.trainsize = trainsize
16
+ self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
17
+ self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')]
18
+ self.ek = [ek_root + f for f in os.listdir(gt_root) if f.endswith('.png')]
19
+
20
+ self.images = sorted(self.images)
21
+ self.gts = sorted(self.gts)
22
+ self.eks = sorted(self.ek)
23
+
24
+ self.size = len(self.images)
25
+ self.img_transform = transforms.Compose([
26
+ transforms.Resize((self.trainsize, self.trainsize)),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
29
+ self.gt_transform = transforms.Compose([
30
+ transforms.Resize((self.trainsize, self.trainsize)),
31
+ transforms.ToTensor()])
32
+ self.ek_transform = transforms.Compose([
33
+ transforms.Resize((self.trainsize, self.trainsize)),
34
+ transforms.ToTensor()])
35
+
36
+
37
+ def __getitem__(self, index):
38
+ image = self.rgb_loader(self.images[index])
39
+ gt = self.binary_loader(self.gts[index])
40
+ ek = self.binary_loader(self.eks[index])
41
+
42
+ image = self.img_transform(image)
43
+ gt = self.gt_transform(gt)
44
+ ek = self.ek_transform(ek)
45
+
46
+ return image, gt, ek
47
+
48
+ def rgb_loader(self, path):
49
+ with open(path, 'rb') as f:
50
+ img = Image.open(f)
51
+ return img.convert('RGB')
52
+
53
+ def binary_loader(self, path):
54
+ with open(path, 'rb') as f:
55
+ img = Image.open(f)
56
+ return img.convert('L')
57
+
58
+ def __len__(self):
59
+ return self.size
60
+
61
+
62
+ def get_loader(image_root, gt_root, ek_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True):
63
+ dataset = SalObjDataset(image_root, gt_root, ek_root, trainsize)
64
+ data_loader = data.DataLoader(dataset=dataset,
65
+ batch_size=batchsize,
66
+ shuffle=shuffle,
67
+ num_workers=num_workers,
68
+ pin_memory=pin_memory)
69
+ return data_loader
70
+
71
+
72
+ class test_dataset:
73
+ def __init__(self, image_root, gt_root, testsize):
74
+ self.testsize = testsize
75
+ self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
76
+ self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
77
+ or f.endswith('.png')]
78
+ self.images = sorted(self.images)
79
+ self.gts = sorted(self.gts)
80
+ self.img_transform = transforms.Compose([
81
+ transforms.Resize((self.testsize, self.testsize)),
82
+ transforms.ToTensor(),
83
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
84
+ self.gt_transform = transforms.ToTensor()
85
+ self.size = len(self.images)
86
+ self.index = 0
87
+
88
+ def load_data(self):
89
+ image = self.rgb_loader(self.images[self.index])
90
+ image = self.img_transform(image).unsqueeze(0)
91
+ gt = self.binary_loader(self.gts[self.index])
92
+ name = self.images[self.index].split('/')[-1]
93
+ if name.endswith('.jpg'):
94
+ name = name.split('.jpg')[0] + '.png'
95
+ self.index += 1
96
+ return image, gt, name
97
+
98
+ def rgb_loader(self, path):
99
+ with open(path, 'rb') as f:
100
+ img = Image.open(f)
101
+ return img.convert('RGB')
102
+
103
+ def binary_loader(self, path):
104
+ with open(path, 'rb') as f:
105
+ img = Image.open(f)
106
+ return img.convert('L')
107
+
108
+
109
+ def transform_image(image, testsize):
110
+ """预处理单张图像用于推理
111
+
112
+ Args:
113
+ image: PIL Image对象
114
+ testsize: 目标尺寸
115
+
116
+ Returns:
117
+ torch.Tensor: 预处理后的图像张量
118
+ """
119
+ transform = transforms.Compose([
120
+ transforms.Resize((testsize, testsize)),
121
+ transforms.ToTensor(),
122
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
123
+ ])
124
+
125
+ return transform(image)