Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """ | |
| This file is a part of project "Aided-Diagnosis-System-for-Cervical-Cancer-Screening". | |
| See https://github.com/ShenghuaCheng/Aided-Diagnosis-System-for-Cervical-Cancer-Screening for more information. | |
| File name: augmentations.py | |
| Description: augmentation functions. | |
| """ | |
| import functools | |
| import random | |
| import cv2 | |
| import numpy as np | |
| from skimage.exposure import adjust_gamma | |
| from loguru import logger | |
# Public API of this module: the random-augmentation factory and the two collections of
# deterministic transforms it is built on.
__all__ = ["Augmentations", "StylisticTrans", "SpatialTrans"]
class Augmentations:
    """
    Factory of random augmentation callables.

    Every factory method returns a closure ``f(img) -> img``; closures can be chained with
    :meth:`Compose`. All random parameters are drawn from fixed ranges chosen for this project.

    img = [size, size, ch]
        ch = 3: image only
        ch = 4: image with a mask stored in the 4th channel

    Stylistic augmentations only touch ``img[..., :3]``, so a mask in the 4th channel is left as is;
    spatial augmentations transform all channels together.
    """

    @staticmethod
    def Compose(*funcs):
        """Chain augmentation callables into a single callable applied left to right.

        Callables named ``norm`` (as returned by :meth:`Normalization`) are moved to the end of the
        chain, keeping the relative order of all other callables, so normalization always runs last.

        :param funcs: callables taking and returning an ndarray
        :return: a callable ``compose(img)``; with no funcs it returns ``img`` unchanged
        """
        funcs = list(funcs)
        # getattr: functools.partial objects and other callables may have no __name__
        is_norm = [getattr(f, '__name__', None) == 'norm' for f in funcs]
        ordered = ([f for f, n in zip(funcs, is_norm) if not n]
                   + [f for f, n in zip(funcs, is_norm) if n])

        def compose(img: np.ndarray):
            for func in ordered:
                img = func(img)
            return img

        return compose

    # ===================================================================================================
    # random stylistic augmentations (applied to the first three channels only)

    @staticmethod
    def RandomGamma(p: float = 0.5):
        """Return a callable that, with probability ``p``, applies gamma correction, gamma in [0.6, 1.2)."""
        def random_gamma(img: np.ndarray):
            if random.random() < p:
                gamma = 0.6 + random.random() * 0.6
                img = img.copy()  # never modify the caller's array in place
                img[..., :3] = StylisticTrans.gamma_adjust(img[..., :3], gamma)
            return img
        return random_gamma

    @staticmethod
    def RandomSharp(p: float = 0.5):
        """Return a callable that, with probability ``p``, sharpens with sigma in [8.3, 8.7)."""
        def random_sharp(img: np.ndarray):
            if random.random() < p:
                sigma = 8.3 + random.random() * 0.4
                img = img.copy()  # never modify the caller's array in place
                img[..., :3] = StylisticTrans.sharp(img[..., :3], sigma)
            return img
        return random_sharp

    @staticmethod
    def RandomGaussainBlur(p: float = 0.5):
        """Return a callable that, with probability ``p``, applies a Gaussian blur, sigma in [0.1, 1.1)."""
        def random_gaussian_blur(img: np.ndarray):
            if random.random() < p:
                sigma = 0.1 + random.random() * 1
                img = img.copy()  # never modify the caller's array in place
                img[..., :3] = StylisticTrans.gaussian_blur(img[..., :3], sigma)
            return img
        return random_gaussian_blur

    # Correctly spelled alias; the misspelled name above is kept for existing callers.
    RandomGaussianBlur = RandomGaussainBlur

    @staticmethod
    def RandomHSVDisturb(p: float = 0.5):
        """Return a callable that, with probability ``p``, applies a random linear HSV disturbance.

        Per channel (H, S, V): k in [0.95, 1.05) x [0.7, 1.5) x [0.75, 1.2),
        b in [-3, 3) x [-10, 10) x [-10, 8); see :meth:`StylisticTrans.hsv_disturb`.
        """
        def random_hsv_disturb(img: np.ndarray):
            if random.random() < p:
                k = np.random.random(3) * [0.1, 0.8, 0.45] + [0.95, 0.7, 0.75]
                b = np.random.random(3) * [6, 20, 18] + [-3, -10, -10]
                img = img.copy()  # never modify the caller's array in place
                img[..., :3] = StylisticTrans.hsv_disturb(img[..., :3], k.tolist(), b.tolist())
            return img
        return random_hsv_disturb

    @staticmethod
    def RandomRGBSwitch(p: float = 0.5):
        """Return a callable that, with probability ``p``, randomly permutes the three colour channels."""
        def random_rgb_switch(img: np.ndarray):
            if random.random() < p:
                bgr_seq = list(range(3))
                random.shuffle(bgr_seq)
                img = img.copy()  # never modify the caller's array in place
                img[..., :3] = StylisticTrans.bgr_switch(img[..., :3], bgr_seq)
            return img
        return random_rgb_switch

    # ===================================================================================================
    # random spatial augmentations; they act on all channels, so tiles and their masks stay aligned

    @staticmethod
    def RandomRotate90(p: float = 0.5):
        """Return a callable that, with probability ``p``, rotates by 90, 180 or 270 degrees."""
        def random_rotate90(img: np.ndarray):
            if random.random() < p:
                angle = 90 * random.randint(1, 3)
                img = SpatialTrans.rotate(img, angle)
            return img
        return random_rotate90

    @staticmethod
    def RandomHorizontalFlip(p: float = 0.5):
        """Return a callable that, with probability ``p``, mirrors the image left-right."""
        def random_horizontal_flip(img: np.ndarray):
            if random.random() < p:
                # cv2 flip code 1 flips around the y-axis, i.e. a left-right (horizontal) flip
                img = SpatialTrans.flip(img, 1)
            return img
        return random_horizontal_flip

    @staticmethod
    def RandomVerticalFlip(p: float = 0.5):
        """Return a callable that, with probability ``p``, flips the image upside down."""
        def random_vertical_flip(img: np.ndarray):
            if random.random() < p:
                # cv2 flip code 0 flips around the x-axis, i.e. an up-down (vertical) flip
                img = SpatialTrans.flip(img, 0)
            return img
        return random_vertical_flip

    @staticmethod
    def RandomScale(p: float = 0.5):
        """Return a callable that, with probability ``p``, rescales by a ratio in [0.8, 1.2), keeping size."""
        def random_scale(img: np.ndarray):
            if random.random() < p:
                ratio = 0.8 + random.random() * 0.4
                img = SpatialTrans.scale(img, ratio, True)
            return img
        return random_scale

    @staticmethod
    def RandomCrop(p: float = 1., size: tuple = (512, 512)):
        """Return a callable cropping a window of ``size`` = (w, h).

        With probability ``p`` the window is placed at a random offset; for a large field of view the
        image is first reduced to a centred region of at most 1.5 * size so the offset stays bounded.
        Otherwise the centre window is taken. Areas outside the image are zero-padded by the crop.
        """
        def random_crop(img: np.ndarray):
            if random.random() < p:
                # for a large FOV, control the translate range
                new_shape = list(img.shape[1::-1])  # (w, h)
                if img.shape[0] > size[1] * 1.5:
                    new_shape[1] = int(size[1] * 1.5)
                if img.shape[1] > size[0] * 1.5:
                    new_shape[0] = int(size[0] * 1.5)
                img = SpatialTrans.center_crop(img.copy(), tuple(new_shape))
                # random translate: both image extent and crop size expressed as (w, h)
                max_xy = np.array(img.shape[1::-1]) - np.array(size)
                xy = np.random.random(2) * max_xy
                bbox = tuple(xy.astype(int).tolist() + list(size))
                img = SpatialTrans.crop(img, bbox)
            else:
                img = SpatialTrans.center_crop(img, size)
            return img
        return random_crop

    @staticmethod
    def Normalization(rng=(-1, 1)):
        """Return a callable named ``norm`` mapping [0, 255] linearly onto ``rng`` = (min, max).

        :meth:`Compose` recognises it by that name and always runs it last.
        """
        def norm(img: np.ndarray):
            img = StylisticTrans.normalization(img, rng)
            return img
        return norm

    @staticmethod
    def CenterCrop(size: tuple = (512, 512)):
        """Return a callable taking the centre window of ``size`` = (w, h), zero-padded if needed."""
        def center_crop(img: np.ndarray):
            img = SpatialTrans.center_crop(img, size)
            return img
        return center_crop
class StylisticTrans:
    """
    Set of augmentations applied to the content (colour / intensity) of an image.

    Every function works on a copy and returns a new array; the input array is not modified.
    """
    # TODO: some of these augmentations could be implemented more efficiently

    @staticmethod
    def gamma_adjust(img: np.ndarray, gamma: float):
        """Adjust gamma (skimage ``adjust_gamma``).

        :param img: a ndarray, preferably BGR
        :param gamma: gamma, recommended value 0.6, range [0.6, 1.2]
        :return: a ndarray
        """
        return adjust_gamma(img.copy(), gamma)

    @staticmethod
    def sharp(img: np.ndarray, sigma: float):
        """Sharpen an image with a 3x3 kernel.

        :param img: a ndarray, preferably BGR
        :param sigma: sharpening degree, recommended range [8.3, 8.7]; must differ from 8
        :return: a ndarray
        """
        # centre weight sigma, neighbours -1, scaled by 1 / (sigma - 8) so the kernel sums to 1 (sharpen)
        kernel = np.array([[-1, -1, -1], [-1, sigma, -1], [-1, -1, -1]], np.float32) / (sigma - 8)
        return cv2.filter2D(img.copy(), -1, kernel=kernel)

    @staticmethod
    def gaussian_blur(img: np.ndarray, sigma: float):
        """Blur an image with a Gaussian kernel.

        :param img: a ndarray, preferably BGR
        :param sigma: blurring degree, recommended range [0.1, 1.1]
        :return: a ndarray
        """
        # odd kernel size 6 * ceil(sigma) + 1, covering about +-3 sigma
        ksize = int(6 * np.ceil(sigma) + 1)
        return cv2.GaussianBlur(img.copy(), (ksize, ksize), sigma)

    @staticmethod
    def hsv_disturb(img: np.ndarray, k: list, b: list):
        """Disturb HSV values with a per-channel linear transform ``v' = k * v + b``.

        :param img: a BGR uint8 ndarray
        :param k: per-channel (H, S, V) gain; low_b = [0.95, 0.7, 0.75], upper_b = [1.05, 1.5, 1.2]
        :param b: per-channel (H, S, V) offset; low_b = [-3, -10, -10], upper_b = [3, 10, 8]
        :return: a BGR ndarray
        """
        hsv = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV).astype(np.float64)
        # broadcast the 3-element gain/offset over the last (channel) axis
        hsv = hsv * np.asarray(k, dtype=np.float64) + np.asarray(b, dtype=np.float64)
        # clip H to [0, 180] and S, V to [1, 255] before casting back to uint8
        hsv = np.clip(hsv, np.array([0, 1, 1]), np.array([180, 255, 255])).astype(np.uint8)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    @staticmethod
    def bgr_switch(img: np.ndarray, bgr_seq: list):
        """Reorder the channels of an image.

        :param img: a ndarray, preferably BGR
        :param bgr_seq: new channel order, e.g. [2, 1, 0]
        :return: a ndarray
        """
        return img.copy()[..., bgr_seq]

    @staticmethod
    def normalization(img: np.ndarray, rng: list):
        """Map values linearly from [0, 255] onto [min, max].

        :param img: a ndarray
        :param rng: target range (min, max)
        :return: a float64 ndarray
        """
        lb, ub = rng
        return img.astype(np.float64) / 255. * (ub - lb) + lb
class SpatialTrans:
    """
    Set of augmentations applied to the spatial layout of an image.

    Each function transforms all channels at once, so a mask stored as an extra channel stays
    aligned with the image.
    """

    @staticmethod
    def rotate(img: np.ndarray, angle: int):
        """Rotate a square image about its centre.

        TODO: only square images are supported; a universal version is needed.

        :param img: a ndarray with img.shape[0] == img.shape[1]
        :param angle: rotation angle in degrees (OpenCV convention: positive is counter-clockwise)
        :return: a ndarray of the same size as the input; content rotated out of the picture is cut
            off and uncovered regions are zero-padded
        """
        assert img.shape[0] == img.shape[1], "Square image needed."
        h, w = img.shape[:2]
        # Rotate about the centre of the pixel grid. Using (w / 2, h / 2) instead shifts the result by
        # one pixel, so a 90-degree rotation would lose one row/column to zero padding.
        center = ((w - 1) / 2, (h - 1) / 2)
        mat = cv2.getRotationMatrix2D(center, angle, scale=1)
        # cv2 expects dsize as (width, height)
        return cv2.warpAffine(img.copy(), mat, (w, h))

    @staticmethod
    def flip(img: np.ndarray, flip_axis: int):
        """Flip an image with ``cv2.flip``.

        :param img: a ndarray
        :param flip_axis: OpenCV flip code: 0 flips around the x-axis (up-down, a vertical flip),
            1 flips around the y-axis (left-right, a horizontal flip), -1 flips around both axes
        :return: a flipped image
        """
        return cv2.flip(img.copy(), flip_axis)

    @staticmethod
    def scale(img: np.ndarray, ratio: float, fix_size: bool = False):
        """Scale an image.

        :param img: a ndarray
        :param ratio: scale ratio, applied to both width and height
        :param fix_size: if True, return the centre area of the scaled image with the same size as the
            input (zero-padded when shrinking)
        :return: a scaled image
        """
        orig_wh = img.shape[1::-1]  # (w, h), the order center_crop expects
        img = cv2.resize(img.copy(), None, fx=ratio, fy=ratio)
        if fix_size:
            img = SpatialTrans.center_crop(img, orig_wh)
        return img

    @staticmethod
    def crop(img: np.ndarray, bbox: tuple):
        """Crop an image according to a bbox, zero-padding any part outside the image.

        :param img: a 2-D (h, w) or 3-D (h, w, ch) ndarray
        :param bbox: cropping area (x, y, w, h); x/y may be negative or extend past the image
        :return: cropped image of shape (h, w[, ch]) with the same dtype as ``img``
        """
        x, y, w, h = (int(v) for v in bbox)
        img_h, img_w = img.shape[:2]
        template = np.zeros((h, w) + img.shape[2:], dtype=img.dtype)
        # intersection of the bbox with the image, in image coordinates
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)
        if x0 >= x1 or y0 >= y1:
            logger.warning("Crop area contains nothing, return a zeros array {}".format(template.shape))
            return template
        # shift the intersection into template coordinates
        template[y0 - y:y1 - y, x0 - x:x1 - x] = img[y0:y1, x0:x1]
        return template

    @staticmethod
    def center_crop(img: np.ndarray, shape: tuple):
        """Return the centre area of the given shape, zero-padded if larger than the image.

        :param img: a ndarray
        :param shape: centre crop shape (w, h)
        :return: cropped image, see :meth:`crop`
        """
        img_h, img_w = img.shape[:2]
        x = img_w // 2 - shape[0] // 2
        y = img_h // 2 - shape[1] // 2
        return SpatialTrans.crop(img, (x, y, shape[0], shape[1]))