kunkk commited on
Commit
dad1346
·
verified ·
1 Parent(s): 71d32d7

Upload joint_transforms.py

Browse files
Files changed (1) hide show
  1. joint_transforms.py +72 -0
joint_transforms.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # joint_transforms.py
2
+
3
+ import math
4
+ import random
5
+ import numpy as np
6
+ from PIL import Image, ImageFilter
7
+
8
+ class Compose(object):
9
+ def __init__(self, transforms):
10
+ self.transforms = transforms
11
+
12
+ def __call__(self, img, gt):
13
+ for t in self.transforms:
14
+ img, gt = t(img, gt)
15
+ return img, gt
16
+
17
+ class RandomScaleCrop(object):
18
+ """多尺度缩放裁剪(同时处理图像和标签)"""
19
+ def __init__(self, base_size=352, crop_size=352, scale_factor=[0.75, 1.0, 1.25]):
20
+ self.base_size = base_size
21
+ self.crop_size = crop_size
22
+ self.scale_factor = scale_factor
23
+
24
+ def __call__(self, img, gt):
25
+ # 随机选择缩放比例
26
+ sf = random.choice(self.scale_factor)
27
+ new_size = int(self.base_size * sf)
28
+
29
+ # 缩放
30
+ img = img.resize((new_size, new_size), Image.BILINEAR)
31
+ gt = gt.resize((new_size, new_size), Image.NEAREST)
32
+
33
+ # 随机裁剪
34
+ x = random.randint(0, new_size - self.crop_size)
35
+ y = random.randint(0, new_size - self.crop_size)
36
+ img = img.crop((x, y, x+self.crop_size, y+self.crop_size))
37
+ gt = gt.crop((x, y, x+self.crop_size, y+self.crop_size))
38
+
39
+ return img, gt
40
+
41
+ class RandomRotate(object):
42
+ """随机旋转(保持图像和标签同步)"""
43
+ def __init__(self, degree=30):
44
+ self.degree = degree
45
+
46
+ def __call__(self, img, gt):
47
+ rotate_degree = random.uniform(-self.degree, self.degree)
48
+ img = img.rotate(rotate_degree, Image.BILINEAR)
49
+ gt = gt.rotate(rotate_degree, Image.NEAREST)
50
+ return img, gt
51
+
52
+ class RandomGaussianBlur(object):
53
+ """随机高斯模糊(仅对图像处理)"""
54
+ def __init__(self, p=0.5):
55
+ self.p = p
56
+
57
+ def __call__(self, img, gt):
58
+ if random.random() < self.p:
59
+ img = img.filter(ImageFilter.GaussianBlur(
60
+ radius=random.uniform(0.5, 2.0)))
61
+ return img, gt
62
+
63
+ class RandomHorizontallyFlip(object):
64
+ """随机水平翻转"""
65
+ def __init__(self, p=0.5):
66
+ self.p = p
67
+
68
+ def __call__(self, img, gt):
69
+ if random.random() < self.p:
70
+ img = img.transpose(Image.FLIP_LEFT_RIGHT)
71
+ gt = gt.transpose(Image.FLIP_LEFT_RIGHT)
72
+ return img, gt