Spaces:
Build error
Build error
import functools
import importlib
import os
from typing import Any, Callable, List, Mapping, Tuple
from urllib.parse import urlparse

import numpy as np
import torch
from torch import Tensor
from torch.hub import download_url_to_file, get_dir
from torch.nn import functional as F
def get_obj_from_str(string: str, reload: bool=False) -> Any:
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module.

    Args:
        string: Dotted path; it must contain at least one ``.``.
        reload: When true, the module is re-executed with
            ``importlib.reload`` before the attribute is looked up.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``string`` contains no dot (the split cannot be unpacked).
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    """Build an object from a ``{"target": ..., "params": ...}`` style mapping.

    Args:
        config: Mapping whose ``"target"`` entry is a dotted path resolved by
            ``get_obj_from_str``. The optional ``"params"`` entry is expanded
            as keyword arguments; when absent, no arguments are passed.

    Returns:
        Whatever calling the resolved target with the given params returns.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.

    Every channel is smoothed independently with the same fixed 3x3 kernel
    (weights sum to 1), applied with a dilation of ``radius``. The input is
    replicate-padded by ``radius`` on each spatial side first, so the output
    keeps the input's spatial size.

    Args:
        image: Tensor laid out as (..., C, H, W); the channel count is read
            from the third-from-last dimension, so it is not limited to 3.
        radius: Dilation of the kernel and amount of padding per side.

    Returns:
        The blurred tensor, same shape as ``image``.
    """
    # One copy of the kernel per channel; taken from the input instead of
    # being fixed to RGB so that other channel counts work too.
    channels = image.shape[-3]
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # (3, 3) -> (1, 1, 3, 3) -> (C, 1, 3, 3): the weight layout of a depthwise conv
    kernel = kernel[None, None].repeat(channels, 1, 1, 1)
    # a dilated 3x3 kernel spans 2 * radius + 1 pixels, so pad radius per side
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # groups == channels keeps the channels from mixing
    return F.conv2d(image, kernel, groups=channels, dilation=radius)
def wavelet_decomposition(image: Tensor, levels: int = 5) -> Tuple[Tensor, Tensor]:
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.

    The image is blurred ``levels`` times with ``wavelet_blur``, doubling the
    radius each time (1, 2, 4, ...). What each blur removes is accumulated
    into the high-frequency part, so ``high_freq + low_freq`` adds back up to
    the input.

    Args:
        image: Tensor to decompose; passed unchanged to ``wavelet_blur``.
        levels: Number of blur passes. With ``levels <= 0`` no blur is
            applied: the high-frequency part is all zeros and the
            low-frequency part is the input tensor itself (not a copy).

    Returns:
        A ``(high_freq, low_freq)`` tuple.
    """
    high_freq = torch.zeros_like(image)
    # Start from the input so the loop may run zero times without leaving
    # low_freq unbound.
    low_freq = image
    for level in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** level)
        # detail lost by this pass
        high_freq += low_freq - blurred
        low_freq = blurred
    return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """
    Recombine two wavelet decompositions so that the content takes on the
    colour of the style.

    Both inputs go through ``wavelet_decomposition``. The result is the
    high-frequency part of ``content_feat`` plus the low-frequency part of
    ``style_feat``; the other two parts are discarded.
    """
    # keep only the high-frequency part of the content ...
    high_from_content = wavelet_decomposition(content_feat)[0]
    # ... and only the low-frequency part of the style
    low_from_style = wavelet_decomposition(style_feat)[1]
    return high_from_content + low_from_style
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path for the file at ``url``, downloading it if it is not cached.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory in which the file is stored; created if it
            does not exist. If None, ``<torch hub dir>/checkpoints`` is used.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Name of the local file. If None, the last path
            component of the URL is used. Default: None.

    Returns:
        str: Absolute path to the cached (or freshly downloaded) file.
    """
    if model_dir is None:
        # fall back to the pytorch hub directory
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    # the URL's path component is used, so any query string is not part of the name
    url_name = os.path.basename(urlparse(url).path)
    target_name = url_name if file_name is None else file_name
    cached_file = os.path.abspath(os.path.join(model_dir, target_name))

    if os.path.exists(cached_file):
        # already on disk: nothing to download
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
    """Enumerate square tiles that cover an ``h`` x ``w`` area.

    Along each axis, tiles start at 0, ``tile_stride``, 2 * ``tile_stride``, ...
    while they still fit. If that leaves the far edge uncovered, one extra
    tile flush with the edge (starting at ``extent - tile_size``) is added.

    Args:
        h: Height of the area to cover.
        w: Width of the area to cover.
        tile_size: Side length of every tile.
        tile_stride: Step between consecutive tile starts.

    Returns:
        A list of ``(h_start, h_end, w_start, w_end)`` tuples with
        ``h_end = h_start + tile_size`` and ``w_end = w_start + tile_size``,
        ordered row-major (all ``w`` positions for the first ``h`` position,
        then the next).

    Raises:
        ValueError: If ``tile_size`` or ``tile_stride`` is not positive, or if
            the tile is larger than the area in either dimension (which would
            otherwise produce negative start offsets).
    """
    if tile_size <= 0 or tile_stride <= 0:
        raise ValueError(
            f"tile_size and tile_stride must be positive, got {tile_size} and {tile_stride}"
        )
    if tile_size > h or tile_size > w:
        raise ValueError(
            f"tile_size {tile_size} does not fit into an area of size {h}x{w}"
        )

    def _starts(extent: int) -> List[int]:
        # Tile start offsets along one axis.
        last = extent - tile_size
        starts = list(range(0, last + 1, tile_stride))
        if last % tile_stride != 0:
            # the regular grid stops short of the edge: add an edge-aligned tile
            starts.append(last)
        return starts

    hi_list = _starts(h)
    wi_list = _starts(w)
    return [
        (hi, hi + tile_size, wi, wi + tile_size)
        for hi in hi_list
        for wi in wi_list
    ]
# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
def gaussian_weights(tile_width: int, tile_height: int, var: float = 0.01) -> np.ndarray:
    """Generates a gaussian mask of weights for tile contributions.

    A 1-D gaussian profile is evaluated per axis at the integer positions
    ``0 .. length - 1``, centred on ``(length - 1) / 2`` with the distance
    measured as a fraction of ``length``. The mask is the outer product of the
    vertical and horizontal profiles.

    Both axes use the same centre, so the mask is symmetric horizontally and
    vertically. (The code this was adapted from centred the vertical profile
    on ``length / 2``, which shifted the peak by half a pixel.)

    Args:
        tile_width: Number of columns of the mask.
        tile_height: Number of rows of the mask.
        var: Variance of the gaussian in normalised coordinates.

    Returns:
        Array of shape ``(tile_height, tile_width)``.
    """

    def _profile(length: int) -> np.ndarray:
        # 1-D gaussian sampled at 0 .. length - 1.
        # -1 because index goes from 0 to length - 1
        midpoint = (length - 1) / 2
        offsets = np.arange(length) - midpoint
        return np.exp(-offsets * offsets / (length * length) / (2 * var)) / np.sqrt(2 * np.pi * var)

    return np.outer(_profile(tile_height), _profile(tile_width))
# Opt-in switch for the VRAM report printed by `count_vram_usage`, read once at import time.
# Unset, empty, or an explicit "off" spelling ("0", "false", "no", "off", any case) disables it;
# every other value enables it. (A plain bool() on the raw string would treat "0" as enabled.)
COUNT_VRAM = os.environ.get("COUNT_VRAM", "").strip().lower() not in ("", "0", "false", "no", "off")
def count_vram_usage(func: Callable) -> Callable:
    """Decorator that prints CUDA peak-memory readings around each call to ``func``.

    When the module-level ``COUNT_VRAM`` flag is false at decoration time,
    ``func`` is returned unchanged. Otherwise the returned wrapper reads
    ``torch.cuda.max_memory_allocated()`` before the call and again after the
    call (following ``torch.cuda.synchronize()``), and prints both values
    scaled by ``1024 ** 3``.

    Args:
        func: The callable to instrument.

    Returns:
        ``func`` itself, or a wrapper that forwards all arguments and the
        return value and carries ``func``'s name and docstring.
    """
    if not COUNT_VRAM:
        return func

    # wraps() keeps func's __name__ / __doc__ on the wrapper
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
        ret = func(*args, **kwargs)
        # wait for outstanding GPU work before taking the second reading
        torch.cuda.synchronize()
        peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
        return ret

    return wrapper