Pathora / backend /augmentations.py
malavikapradeep2001's picture
Initial Space
bf5da6b
raw
history blame
11.3 kB
# -*- coding: utf-8 -*-
"""
This file is a part of project "Aided-Diagnosis-System-for-Cervical-Cancer-Screening".
See https://github.com/ShenghuaCheng/Aided-Diagnosis-System-for-Cervical-Cancer-Screening for more information.
File name: augmentations.py
Description: augmentation functions.
"""
import functools
import random
import cv2
import numpy as np
from skimage.exposure import adjust_gamma
from loguru import logger
# Public names exported by a star-import of this module: the three
# augmentation classes defined in this file.
__all__ = [
"Augmentations",
"StylisticTrans",
"SpatialTrans",
]
class Augmentations:
    """Factories that build randomized augmentation callables.

    Every factory returns a function that takes an image array and returns
    an image array. All sampling ranges are fixed inside the factories.

    Image layout is ``[size, size, ch]``:
        ch = 3: image only
        ch = 4: image with its mask stored as the 4th channel

    Stylistic augmentations only modify the first three channels and write
    the result back into the array they were given. Spatial augmentations
    transform every channel and, when applied, return a new array.
    """

    @staticmethod
    def Compose(*funcs):
        """Chain augmentation callables into one callable.

        :param funcs: callables applied left to right, each taking and
            returning an image
        :return: a function applying every callable in order; with no
            callables it returns its input unchanged

        If one of the callables is named ``norm`` (the name of the function
        built by :meth:`Normalization`), the first such callable is moved to
        the end of the chain so that it runs last and only once.
        """
        pipeline = list(funcs)
        # Objects such as functools.partial carry no __name__; treat them as
        # ordinary (non-norm) steps instead of failing.
        names = [getattr(func, "__name__", None) for func in pipeline]
        if "norm" in names:
            pipeline.append(pipeline.pop(names.index("norm")))

        def compose(img: np.ndarray):
            for func in pipeline:
                img = func(img)
            return img

        return compose

    # ======================================================================
    # random stylistic augmentations
    # ======================================================================

    @staticmethod
    def RandomGamma(p: float = 0.5):
        """Build a gamma adjustment applied with probability ``p``.

        Gamma is drawn as ``0.6 + u * 0.6`` with ``u = random.random()``.
        """
        def random_gamma(img: np.ndarray):
            if random.random() < p:
                gamma = 0.6 + random.random() * 0.6
                img[..., :3] = StylisticTrans.gamma_adjust(img[..., :3], gamma)
            return img

        return random_gamma

    @staticmethod
    def RandomSharp(p: float = 0.5):
        """Build a sharpening step applied with probability ``p``.

        The sharpening degree is drawn as ``8.3 + u * 0.4``.
        """
        def random_sharp(img: np.ndarray):
            if random.random() < p:
                sigma = 8.3 + random.random() * 0.4
                img[..., :3] = StylisticTrans.sharp(img[..., :3], sigma)
            return img

        return random_sharp

    @staticmethod
    def RandomGaussainBlur(p: float = 0.5):
        """Build a Gaussian blur applied with probability ``p``.

        The blur sigma is drawn as ``0.1 + u * 1``.
        """
        def random_gaussian_blur(img: np.ndarray):
            if random.random() < p:
                sigma = 0.1 + random.random() * 1
                img[..., :3] = StylisticTrans.gaussian_blur(img[..., :3], sigma)
            return img

        return random_gaussian_blur

    # Correctly spelled alias; the original (misspelled) name is kept so
    # existing callers continue to work.
    RandomGaussianBlur = RandomGaussainBlur

    @staticmethod
    def RandomHSVDisturb(p: float = 0.5):
        """Build an HSV disturbance applied with probability ``p``.

        Per-channel gains ``k`` and offsets ``b`` are drawn uniformly:
        ``k`` from ``[0.95, 0.7, 0.75]`` up to ``[1.05, 1.5, 1.2]`` and
        ``b`` from ``[-3, -10, -10]`` up to ``[3, 10, 8]``.
        """
        def random_hsv_disturb(img: np.ndarray):
            if random.random() < p:
                k = np.random.random(3) * [0.1, 0.8, 0.45] + [0.95, 0.7, 0.75]
                b = np.random.random(3) * [6, 20, 18] + [-3, -10, -10]
                img[..., :3] = StylisticTrans.hsv_disturb(img[..., :3], k.tolist(), b.tolist())
            return img

        return random_hsv_disturb

    @staticmethod
    def RandomRGBSwitch(p: float = 0.5):
        """Build a random permutation of the three colour channels.

        Applied with probability ``p``; the shuffled order may be the
        identity.
        """
        def random_rgb_switch(img: np.ndarray):
            if random.random() < p:
                bgr_seq = list(range(3))
                random.shuffle(bgr_seq)
                img[..., :3] = StylisticTrans.bgr_switch(img[..., :3], bgr_seq)
            return img

        return random_rgb_switch

    # ======================================================================
    # random spatial augmentations; these can be applied to tiles together
    # with their masks.
    # ======================================================================

    @staticmethod
    def RandomRotate90(p: float = 0.5):
        """Build a rotation by 90, 180 or 270 degrees, applied with probability ``p``."""
        def random_rotate90(img: np.ndarray):
            if random.random() < p:
                angle = 90 * random.randint(1, 3)
                img = SpatialTrans.rotate(img, angle)
            return img

        return random_rotate90

    @staticmethod
    def RandomHorizontalFlip(p: float = 0.5):
        """Build a flip calling ``SpatialTrans.flip(img, 0)`` with probability ``p``."""
        def random_horizontal_flip(img: np.ndarray):
            if random.random() < p:
                img = SpatialTrans.flip(img, 0)
            return img

        return random_horizontal_flip

    @staticmethod
    def RandomVerticalFlip(p: float = 0.5):
        """Build a flip calling ``SpatialTrans.flip(img, 1)`` with probability ``p``."""
        def random_vertical_flip(img: np.ndarray):
            if random.random() < p:
                img = SpatialTrans.flip(img, 1)
            return img

        return random_vertical_flip

    @staticmethod
    def RandomScale(p: float = 0.5):
        """Build a rescale by a ratio drawn as ``0.8 + u * 0.4``.

        Applied with probability ``p``; the output keeps the input size
        (centre area of the scaled image).
        """
        def random_scale(img: np.ndarray):
            if random.random() < p:
                ratio = 0.8 + random.random() * 0.4
                img = SpatialTrans.scale(img, ratio, True)
            return img

        return random_scale

    @staticmethod
    def RandomCrop(p: float = 1., size: tuple = (512, 512)):
        """Build a crop to ``size``.

        :param p: probability of a randomly placed crop; otherwise the
            centre crop is returned
        :param size: output size as (w, h)
        :return: a function returning the cropped image
        """
        def random_crop(img: np.ndarray):
            if random.random() < p:
                # For a large field of view, first shrink to at most 1.5x the
                # target size so the random translation stays near the centre.
                new_shape = list(img.shape[:2][::-1])  # (w, h)
                if img.shape[0] > size[1] * 1.5:
                    new_shape[1] = int(size[1] * 1.5)
                if img.shape[1] > size[0] * 1.5:
                    new_shape[0] = int(size[0] * 1.5)
                img = SpatialTrans.center_crop(img.copy(), tuple(new_shape))
                # Random top-left corner. The free range is computed in
                # (x, y) order so it lines up with ``size`` = (w, h).
                free = np.array(img.shape[1::-1]) - np.array(size)
                xy = np.random.random(2) * free
                # ``np.int`` was removed from NumPy; the builtin is equivalent.
                bbox = tuple(xy.astype(int).tolist() + list(size))
                img = SpatialTrans.crop(img, bbox)
            else:
                img = SpatialTrans.center_crop(img, size)
            return img

        return random_crop

    @staticmethod
    def Normalization(rng: tuple = (-1, 1)):
        """Build the normalization step.

        :param rng: target range as (min, max); any two-element sequence
        :return: a function named ``norm``; ``Compose`` relies on that name
            to place the step last
        """
        def norm(img: np.ndarray):
            return StylisticTrans.normalization(img, rng)

        return norm

    @staticmethod
    def CenterCrop(size: tuple = (512, 512)):
        """Build a deterministic centre crop to ``size`` given as (w, h)."""
        def center_crop(img: np.ndarray):
            return SpatialTrans.center_crop(img, size)

        return center_crop
class StylisticTrans:
    """Set of augmentations applied to the content (pixel values) of an image.

    Every function returns a new array and leaves its input untouched.
    """
    # TODO Some implementations of augmentation need a more efficient way

    @staticmethod
    def gamma_adjust(img: np.ndarray, gamma: float):
        """Adjust gamma.

        :param img: a ndarray, better a BGR
        :param gamma: gamma, recommended value 0.6, range [0.6, 1.2]
        :return: a ndarray
        """
        return adjust_gamma(img.copy(), gamma)

    @staticmethod
    def sharp(img: np.ndarray, sigma: float):
        """Sharpen an image with a 3x3 kernel.

        :param img: a ndarray, better a BGR
        :param sigma: sharp degree, recommended range [8.3, 8.7]; it must
            differ from 8 because the kernel is divided by ``sigma - 8``
        :return: a ndarray
        """
        # The raw kernel sums to ``sigma - 8``; dividing by that value
        # normalizes the kernel sum to 1.
        kernel = np.array(
            [[-1, -1, -1],
             [-1, sigma, -1],
             [-1, -1, -1]],
            np.float32,
        ) / (sigma - 8)
        return cv2.filter2D(img.copy(), -1, kernel=kernel)

    @staticmethod
    def gaussian_blur(img: np.ndarray, sigma: float):
        """Blur an image with a Gaussian kernel.

        :param img: a ndarray, better a BGR
        :param sigma: blurring degree, recommended range [0.1, 1.1]
        :return: a ndarray
        """
        # Kernel side length is 6 * ceil(sigma) + 1, which is always odd.
        ksize = int(6 * np.ceil(sigma) + 1)
        return cv2.GaussianBlur(img.copy(), (ksize, ksize), sigma)

    @staticmethod
    def hsv_disturb(img: np.ndarray, k: list, b: list):
        """Disturb the HSV values with a per-channel linear map ``k * v + b``.

        :param img: a BGR ndarray
        :param k: low_b = [0.95, 0.7, 0.75], upper_b = [1.05, 1.5, 1.2]
        :param b: low_b = [-3, -10, -10], upper_b = [3, 10, 8]
        :return: a BGR ndarray
        """
        # ``np.float`` was removed from NumPy; float64 is what it aliased.
        hsv = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV).astype(np.float64)
        gain = np.asarray(k[:3], dtype=np.float64)
        offset = np.asarray(b[:3], dtype=np.float64)
        # Broadcasting over the last axis applies gain/offset per channel.
        hsv = hsv * gain + offset
        hsv = np.clip(hsv, np.array([0, 1, 1]), np.array([180, 255, 255]))
        return cv2.cvtColor(np.uint8(hsv), cv2.COLOR_HSV2BGR)

    @staticmethod
    def bgr_switch(img: np.ndarray, bgr_seq: list):
        """Reorder the channels of the last axis.

        :param img: a ndarray, better a BGR
        :param bgr_seq: new channel sequence
        :return: a ndarray
        """
        # Indexing with a list already produces a fresh array.
        return img[..., list(bgr_seq)]

    @staticmethod
    def normalization(img: np.ndarray, rng: list):
        """Map values from [0, 255] linearly onto ``[min, max]``.

        :param img: a ndarray; every channel is scaled
        :param rng: target range [min, max]
        :return: a float64 ndarray
        """
        lb, ub = rng
        delta = ub - lb
        # ``astype`` returns a new array, so the input is not modified.
        return (img.astype(np.float64) / 255.) * delta + lb
class SpatialTrans:
    """Set of augmentations applied to the spatial layout of an image.

    Every function returns a new array and leaves its input untouched.
    """

    @staticmethod
    def rotate(img: np.ndarray, angle: int):
        """Rotate a square image about its centre.

        # todo Square image and central rotate only, a universal version is needed
        :param img: a square ndarray
        :param angle: rotate angle in degrees
        :return: a ndarray of the same size as the input; areas rotated in
            from outside are zero, areas rotated out are cut off
        :raises AssertionError: if the image is not square
        """
        height, width = img.shape[:2]
        if height != width:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError("Square image needed.")
        # Pixel centres sit on integer coordinates, so the geometric centre
        # is at (n - 1) / 2. Using n / 2 shifts 90-degree rotations by one
        # pixel and leaves an empty border line.
        center = ((width - 1) / 2.0, (height - 1) / 2.0)
        mat = cv2.getRotationMatrix2D(center, angle, scale=1)
        return cv2.warpAffine(img.copy(), mat, (width, height))

    @staticmethod
    def flip(img: np.ndarray, flip_axis: int):
        """Flip an image.

        :param img: a ndarray
        :param flip_axis: forwarded unchanged as the ``flipCode`` of
            ``cv2.flip``: 0 flips around the x-axis, 1 flips around the
            y-axis
        :return: a flipped image
        """
        return cv2.flip(img.copy(), flip_axis)

    @staticmethod
    def scale(img: np.ndarray, ratio: float, fix_size: bool = False):
        """Scale an image by the same ratio along both axes.

        :param img: a ndarray
        :param ratio: scale ratio
        :param fix_size: return the centre area of the scaled image, with
            the same size as the image before scaling
        :return: a scaled image
        """
        original_wh = img.shape[:2][::-1]
        scaled = cv2.resize(img.copy(), None, fx=ratio, fy=ratio)
        if fix_size:
            scaled = SpatialTrans.center_crop(scaled, original_wh)
        return scaled

    @staticmethod
    def crop(img: np.ndarray, bbox: tuple):
        """Crop an image according to a bounding box.

        :param img: a ndarray, either (h, w) or (h, w, ch)
        :param bbox: cropping area as (x, y, w, h); it may extend beyond the
            image
        :return: a uint8 array of shape (h, w[, ch]); parts of the bbox
            outside the image are filled with zeros
        """
        x, y, w, h = (int(v) for v in bbox)
        img_h, img_w = img.shape[:2]
        template = np.zeros([h, w] + list(img.shape[2:]), dtype=np.uint8)

        # No overlap at all: the bbox starts past the right/bottom edge or
        # ends before the left/top edge.
        if x >= img_w or y >= img_h or x + w <= 0 or y + h <= 0:
            logger.warning("Crop area contains nothing, return a zeros array {}".format(template.shape))
            return template

        # Overlap of the bbox with the image, in image coordinates.
        src_top, src_bottom = max(y, 0), min(y + h, img_h)
        src_left, src_right = max(x, 0), min(x + w, img_w)
        # Where that overlap lands inside the output template.
        dst_top, dst_left = max(-y, 0), max(-x, 0)
        dst_bottom = dst_top + (src_bottom - src_top)
        dst_right = dst_left + (src_right - src_left)

        # Indexing only the two spatial axes works for 2-D and 3-D arrays.
        template[dst_top:dst_bottom, dst_left:dst_right] = img[src_top:src_bottom, src_left:src_right]
        return template

    @staticmethod
    def center_crop(img: np.ndarray, shape: tuple):
        """Return the centre area of an image.

        :param img: a ndarray
        :param shape: centre crop shape (w, h)
        :return: the cropped image as produced by :meth:`crop`
        """
        img_h, img_w = img.shape[:2]
        w, h = shape
        # Top-left corner placing the bbox centre on the image centre.
        x = img_w // 2 - w // 2
        y = img_h // 2 - h // 2
        return SpatialTrans.crop(img, (x, y, w, h))