Spaces:
Configuration error
Configuration error
| import os | |
| import random | |
| import pickle | |
| from pathlib import Path | |
| from itertools import repeat | |
| from multiprocessing.pool import Pool, ThreadPool | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, distributed | |
| from tqdm import tqdm | |
| from ..augmentations import augment_hsv | |
| from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker, get_hash, verify_image_label, HELP_URL, TQDM_BAR_FORMAT, LOCAL_RANK | |
| from ..general import NUM_THREADS, LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn | |
| from ..torch_utils import torch_distributed_zero_first | |
| from ..coco_utils import annToMask, getCocoIds | |
| from .augmentations import mixup, random_perspective, copy_paste, letterbox | |
# Process rank read from the RANK environment variable (-1 when the variable is unset).
# Within this file it is used to offset the DataLoader generator seed in create_dataloader().
RANK = int(os.getenv('RANK', -1))
def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      close_mosaic=False,
                      quad=False,
                      prefix='',
                      shuffle=False,
                      mask_downsample_ratio=1,
                      overlap_mask=False):
    """Build a LoadImagesAndLabelsAndMasks dataset and a loader over it.

    Args:
        path: Dataset location, forwarded to the dataset constructor.
        imgsz: Target image size, forwarded to the dataset constructor.
        batch_size: Requested batch size; capped at ``len(dataset)``.
        stride: Model stride; cast to ``int`` before being forwarded.
        single_cls, hyp, augment, pad, rect, image_weights, prefix:
            Forwarded unchanged to the dataset constructor.
        cache: Forwarded to the dataset as ``cache_images``.
        rank: ``-1`` disables the distributed sampler; any other value
            enables ``DistributedSampler``.
        workers: Upper bound on the number of loader worker processes.
        close_mosaic: When truthy (or when ``image_weights`` is truthy) a plain
            ``DataLoader`` is used instead of ``InfiniteDataLoader``.
        quad: Selects ``collate_fn4`` instead of ``collate_fn``.
        shuffle: Request shuffling; forced off when ``rect`` is set.
        mask_downsample_ratio: Forwarded to the dataset as ``downsample_ratio``.
        overlap_mask: Forwarded to the dataset as ``overlap``.

    Returns:
        tuple: ``(loader, dataset)``.
    """
    if rect and shuffle:
        LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False

    dataset_kwargs = dict(
        augment=augment,  # augmentation
        hyp=hyp,  # hyperparameters
        rect=rect,  # rectangular batches
        cache_images=cache,
        single_cls=single_cls,
        stride=int(stride),
        pad=pad,
        image_weights=image_weights,
        prefix=prefix,
        downsample_ratio=mask_downsample_ratio,
        overlap=overlap_mask,
    )
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabelsAndMasks(path, imgsz, batch_size, **dataset_kwargs)

    batch_size = min(batch_size, len(dataset))

    # Worker count: bounded by CPUs per CUDA device, by the batch size
    # (zero workers for a batch of one) and by the caller's limit.
    cuda_devices = torch.cuda.device_count()
    cpus_per_device = os.cpu_count() // max(cuda_devices, 1)
    batch_bound = batch_size if batch_size > 1 else 0
    num_workers = min([cpus_per_device, batch_bound, workers])

    if rank == -1:
        sampler = None
    else:
        sampler = distributed.DistributedSampler(dataset, shuffle=shuffle)

    # The original author notes that only the plain DataLoader allows attribute updates.
    if image_weights or close_mosaic:
        loader_cls = DataLoader
    else:
        loader_cls = InfiniteDataLoader

    seeded = torch.Generator()
    seeded.manual_seed(6148914691236517205 + RANK)

    if quad:
        collate = LoadImagesAndLabelsAndMasks.collate_fn4
    else:
        collate = LoadImagesAndLabelsAndMasks.collate_fn

    data_loader = loader_cls(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle and sampler is None,  # the sampler owns shuffling when present
        num_workers=num_workers,
        sampler=sampler,
        pin_memory=True,
        collate_fn=collate,
        worker_init_fn=seed_worker,
        generator=seeded,
    )
    return data_loader, dataset
def img2stuff_paths(img_paths):
    """Map image paths to their semantic ("stuff") label paths.

    For each path, the last ``<sep>images<sep>`` component is replaced by
    ``<sep>stuff<sep>`` and the final extension is replaced by ``.txt``.

    Args:
        img_paths: Iterable of image path strings.

    Returns:
        list[str]: One label path per input path, in the same order.
    """
    images_dir = f'{os.sep}images{os.sep}'  # /images/ substring
    stuff_dir = f'{os.sep}stuff{os.sep}'  # /stuff/ substring

    label_paths = []
    for img_path in img_paths:
        # Split on the last occurrence only, so earlier "images" directories are preserved.
        pieces = img_path.rsplit(images_dir, 1)
        relocated = stuff_dir.join(pieces)
        without_ext = relocated.rsplit('.', 1)[0]
        label_paths.append(without_ext + '.txt')
    return label_paths
class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels):  # for training/testing
    """Dataset yielding images, box labels, instance masks and semantic ("stuff") masks.

    Extends ``LoadImagesAndLabels`` with a second label set read from the
    ``stuff`` directory (see ``img2stuff_paths``). Each sample's semantic polygons
    are rasterized into a ``(len(self.coco_ids), H, W)`` uint8 tensor, where H and W
    are the image height and width divided by ``downsample_ratio``.
    """

    def __init__(
        self,
        path,
        img_size=640,
        batch_size=16,
        augment=False,
        hyp=None,
        rect=False,
        image_weights=False,
        cache_images=False,
        single_cls=False,
        stride=32,
        pad=0,
        min_items=0,
        prefix="",
        downsample_ratio=1,
        overlap=False,
    ):
        """Initialise the base dataset, then load (or build) the semantic label cache.

        Args:
            path .. prefix: Forwarded positionally to ``LoadImagesAndLabels.__init__``.
            downsample_ratio: Integer factor by which mask height/width are divided.
            overlap: If true, instance masks are merged into one index map per image
                (see ``polygons2masks_overlap``) instead of one binary mask per instance.

        Raises:
            AssertionError: If ``augment`` is set and no semantic labels were found,
                or all of them are empty.
        """
        super().__init__(
            path,
            img_size,
            batch_size,
            augment,
            hyp,
            rect,
            image_weights,
            cache_images,
            single_cls,
            stride,
            pad,
            min_items,
            prefix)
        self.downsample_ratio = downsample_ratio
        self.overlap = overlap

        # semantic segmentation
        self.coco_ids = getCocoIds()

        # Check cache
        self.seg_files = img2stuff_paths(self.im_files)  # labels
        p = Path(path)
        cache_path = (p.with_suffix('') if p.is_file() else Path(self.seg_files[0]).parent)
        cache_path = Path(str(cache_path) + '_stuff').with_suffix('.cache')
        try:
            # SECURITY NOTE: allow_pickle=True unpickles the cache file; only load caches you created.
            # NOTE: version/hash validation is intentionally not performed here, so an existing
            # cache file is reused even if the dataset changed. Delete the file to force a rebuild.
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
        except Exception:
            # Any load failure (missing/unreadable file) falls back to rebuilding the cache.
            cache, exists = self.cache_seg_labels(cache_path, prefix), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
            tqdm(None, desc=(prefix + d), total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert (0 < nf) or (not augment), f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

        # Read cache
        for k in ('hash', 'version', 'msgs'):  # remove bookkeeping items, leaving only per-image entries
            cache.pop(k)
        # NOTE(review): entries are consumed in cache insertion order and are assumed to line up
        # with self.im_files / self.labels by position -- confirm for filtered or stale caches.
        seg_labels, _, semantic_masks = zip(*cache.values())
        # Keep a list (not the tuple produced by zip) so per-image entries can be replaced below.
        self.semantic_masks = list(semantic_masks)
        nl = len(np.concatenate(seg_labels, 0))  # number of labels
        assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'

        # Update labels
        self.seg_cls = []  # per image: list of class ids, one per semantic polygon
        include_class = []  # filter labels to include only these classes (optional)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, polygons) in enumerate(zip(seg_labels, self.semantic_masks)):
            cls_ids = label[:, 0].astype(int)
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                cls_ids = cls_ids[j]  # keep class ids aligned with the filtered polygons
                if polygons:
                    # polygons is a list of arrays, so filter it element-wise.
                    self.semantic_masks[i] = [poly for poly, keep in zip(polygons, j) if keep]
            if single_cls:  # single-class training, merge all classes into 0
                # The class id of each semantic polygon is stored in seg_cls, not in the polygon
                # points, so the merge is applied to the class ids.
                cls_ids = np.zeros_like(cls_ids)
            self.seg_cls.append(cls_ids.tolist())

    def __getitem__(self, index):
        """Load one training/validation sample.

        Returns:
            tuple: ``(img, labels_out, path, shapes, masks, semantic_seg_masks)`` where
                ``img`` is a CHW tensor with channels reversed from the loaded image,
                ``labels_out`` is an ``(nl, 6)`` tensor whose column 0 is left as zero
                (filled with the batch index by ``collate_fn``), ``shapes`` is ``None``
                for mosaic samples, ``masks`` holds the instance masks, and
                ``semantic_seg_masks`` is a ``(len(self.coco_ids), H, W)`` uint8 tensor.
        """
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        masks = []
        if mosaic:
            # Load mosaic
            img, labels, segments, seg_cls, semantic_masks = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp["mixup"]:
                img, labels, segments, seg_cls, semantic_masks = mixup(
                    img, labels, segments, seg_cls, semantic_masks,
                    *self.load_mosaic(random.randint(0, self.n - 1)))
        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            # Normalized polygons -> pixel coordinates in the letterboxed image.
            # New lists are built, so the cached per-image lists are left untouched.
            segments = [
                xyn2xy(seg, ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
                for seg in self.segments[index]
            ]
            seg_cls = self.seg_cls[index].copy()
            semantic_masks = [
                xyn2xy(poly, ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
                for poly in self.semantic_masks[index]
            ]
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels, segments, semantic_masks = random_perspective(
                    img,
                    labels,
                    segments=segments,
                    semantic_masks=semantic_masks,
                    degrees=hyp["degrees"],
                    translate=hyp["translate"],
                    scale=hyp["scale"],
                    shear=hyp["shear"],
                    perspective=hyp["perspective"])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
            if self.overlap:
                masks, sorted_idx = polygons2masks_overlap(img.shape[:2],
                                                           segments,
                                                           downsample_ratio=self.downsample_ratio)
                masks = masks[None]  # (H, W) -> (1, H, W)
                labels = labels[sorted_idx]  # keep label rows in the same order as the mask ids
            else:
                masks = polygons2masks(img.shape[:2], segments, color=1, downsample_ratio=self.downsample_ratio)

        # No instance masks: fall back to an all-zero tensor of the expected shape.
        masks = (torch.from_numpy(masks) if len(masks) else torch.zeros(
            1 if self.overlap else nl,
            img.shape[0] // self.downsample_ratio,
            img.shape[1] // self.downsample_ratio))

        semantic_masks = polygons2masks(img.shape[:2], semantic_masks, color=1,
                                        downsample_ratio=self.downsample_ratio)
        semantic_masks = torch.from_numpy(semantic_masks)

        # TODO: albumentations support
        if self.augment:
            # Albumentations
            # there are some augmentation that won't change boxes and masks,
            # so just be it for now.
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations
            ns = len(semantic_masks)

            # HSV color-space
            augment_hsv(img, hgain=hyp["hsv_h"], sgain=hyp["hsv_s"], vgain=hyp["hsv_v"])

            # Flip up-down
            if random.random() < hyp["flipud"]:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]
                    masks = torch.flip(masks, dims=[1])
                if ns:
                    semantic_masks = torch.flip(semantic_masks, dims=[1])

            # Flip left-right
            if random.random() < hyp["fliplr"]:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]
                    masks = torch.flip(masks, dims=[2])
                if ns:
                    semantic_masks = torch.flip(semantic_masks, dims=[2])

            # Cutouts  # labels = cutout(img, labels, p=0.5)

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Combine semantic masks: OR every polygon mask into the channel of its class id.
        semantic_seg_masks = torch.zeros((len(self.coco_ids),
                                          img.shape[0] // self.downsample_ratio,
                                          img.shape[1] // self.downsample_ratio), dtype=torch.uint8)
        for cls_id, semantic_mask in zip(seg_cls, semantic_masks):
            semantic_seg_masks[cls_id] = (semantic_seg_masks[cls_id].logical_or(semantic_mask)).int()

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks, semantic_seg_masks)

    def load_mosaic(self, index):
        """Build a 4-image mosaic from ``index`` plus three randomly chosen images.

        Returns:
            tuple: ``(img4, labels4, segments4, seg_cls, semantic_masks4)`` as returned by
                ``copy_paste`` followed by ``random_perspective``.
        """
        # YOLO 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4, seg_cls, semantic_masks4 = [], [], [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y

        # 3 additional image indices
        indices = [index] + random.choices(self.indices, k=3)
        for i, idx in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(idx)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            labels = self.labels[idx].copy()
            segments = self.segments[idx].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            # Semantic polygons are converted whether or not the image has instance labels,
            # so images with only semantic annotations never keep normalized coordinates.
            semantic_masks = [xyn2xy(x, w, h, padw, padh) for x in self.semantic_masks[idx]]

            labels4.append(labels)
            segments4.extend(segments)
            seg_cls.extend(self.seg_cls[idx].copy())
            semantic_masks4.extend(semantic_masks)

        # Concat/clip labels: every box, instance polygon and semantic polygon is clipped to the
        # 2s x 2s mosaic canvas, independently of how many of each there are.
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4, *semantic_masks4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()

        # Augment
        img4, labels4, segments4, seg_cls, semantic_masks4 = copy_paste(
            img4, labels4, segments4, seg_cls, semantic_masks4, p=self.hyp["copy_paste"])
        img4, labels4, segments4, semantic_masks4 = random_perspective(
            img4,
            labels4,
            segments4,
            semantic_masks4,
            degrees=self.hyp["degrees"],
            translate=self.hyp["translate"],
            scale=self.hyp["scale"],
            shear=self.hyp["shear"],
            perspective=self.hyp["perspective"],
            border=self.mosaic_border)  # border to remove
        return img4, labels4, segments4, seg_cls, semantic_masks4

    def cache_seg_labels(self, path=Path('./labels_stuff.cache'), prefix=''):
        """Verify every (image, stuff-label) pair and write the results to a cache file.

        Args:
            path: Destination cache file.
            prefix: String prepended to log/progress messages.

        Returns:
            dict: ``{im_file: [labels, shape, segments], ...}`` plus the bookkeeping keys
                ``'hash'``, ``'results'``, ``'msgs'`` and ``'version'``. The dict is
                returned even when writing the cache file fails.
        """
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.seg_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.seg_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time (np.save appends '.npy')
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            # Best effort: a read-only dataset directory must not prevent training.
            LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

    @staticmethod
    def collate_fn(batch):
        """Collate a list of ``__getitem__`` samples into one batch.

        Column 0 of every label tensor is overwritten in place with the sample's
        position in the batch. Labels and instance masks are concatenated along
        dim 0; images and semantic masks are stacked along a new dim 0.

        Returns:
            tuple: ``(imgs, labels, paths, shapes, masks, semantic_masks)``.
        """
        img, label, path, shapes, masks, semantic_masks = zip(*batch)  # transposed
        batched_masks = torch.cat(masks, 0)
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes, batched_masks, torch.stack(semantic_masks, 0)
def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
    """Rasterize polygons into a single uint8 mask.

    Args:
        img_size (tuple): ``(height, width)`` of the canvas the polygons are drawn on.
        polygons: Array-like of shape ``[N, M]``: N polygons, each a flat sequence
            of M values that is reshaped into ``M / 2`` (x, y) points.
        color (int): Value written inside the polygons.
        downsample_ratio (int): Integer factor by which the mask is shrunk after drawing.

    Returns:
        np.ndarray: Mask of shape ``(height // downsample_ratio, width // downsample_ratio)``.
    """
    canvas = np.zeros(img_size, dtype=np.uint8)
    points = np.asarray(polygons).astype(np.int32)
    points = points.reshape(points.shape[0], -1, 2)  # (N, M) -> (N, M/2, 2)
    cv2.fillPoly(canvas, points, color=color)

    out_h = img_size[0] // downsample_ratio
    out_w = img_size[1] // downsample_ratio
    # NOTE: fillPoly firstly then resize is trying the keep the same way
    # of loss calculation when mask-ratio=1.
    return cv2.resize(canvas, (out_w, out_h))  # cv2.resize takes (width, height)
def polygons2masks(img_size, polygons, color, downsample_ratio=1):
    """Rasterize each polygon into its own mask.

    Args:
        img_size (tuple): ``(height, width)`` of the drawing canvas.
        polygons (list[np.ndarray]): One array of points per polygon; each is
            flattened before being passed to ``polygon2mask``.
        color (int): Fill value forwarded to ``polygon2mask``.
        downsample_ratio (int): Downsampling factor forwarded to ``polygon2mask``.

    Returns:
        np.ndarray: ``np.array`` of the per-polygon masks, in input order.
    """
    per_polygon = [
        polygon2mask(img_size, [poly.reshape(-1)], color, downsample_ratio)
        for poly in polygons
    ]
    return np.array(per_polygon)
def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
    """Merge all instance polygons into one index map.

    Instances are painted from largest to smallest mask area, so smaller instances
    overwrite larger ones where they overlap. A pixel's value is the 1-based rank of
    the instance covering it (0 = background).

    Args:
        img_size (tuple): ``(height, width)`` of the drawing canvas.
        segments (list[np.ndarray]): One array of polygon points per instance.
        downsample_ratio (int): Integer factor by which the output is shrunk.

    Returns:
        tuple: ``(masks, index)``. ``masks`` has shape
            ``(height // downsample_ratio, width // downsample_ratio)`` and dtype uint8,
            or int32 when there are more than 255 segments. ``index`` is the area-sorted
            order of ``segments``; pixel value ``k`` refers to ``segments[index[k - 1]]``.
    """
    masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
                     dtype=np.int32 if len(segments) > 255 else np.uint8)
    ms = [
        polygon2mask(img_size, [segment.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
        for segment in segments
    ]
    areas = np.asarray([mask.sum() for mask in ms])
    index = np.argsort(-areas)

    # Paint by assignment rather than "add then clip": adding ids in uint8 wraps around on
    # overlapping pixels once ids exceed 127, and multiplying a uint8 mask by an id above 255
    # overflows. For 0/1 masks the assignment yields exactly what the clipped sum intended.
    for rank, si in enumerate(index, start=1):
        masks[ms[si] > 0] = rank
    return masks, index