Spaces:
Paused
Paused
| import os | |
| import comfy | |
| import comfy.utils | |
| import cv2 | |
| import folder_paths | |
| import numpy as np | |
| import torch | |
| from comfy import model_management | |
| from PIL import Image | |
| from ..log import NullWriter, log | |
| from ..utils import get_model_path, np2tensor, pil2tensor, tensor2np | |
class MTB_LoadFaceEnhanceModel:
    """Loads a GFPGan or RestoreFormer model for face enhancement."""

    RETURN_TYPES = ("FACEENHANCE_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def __init__(self) -> None:
        pass

    @classmethod
    def get_models_root(cls):
        """Locate the folders that may hold face restoration checkpoints.

        Returns:
            A ``(face_restore, upscale_models)`` pair:

            * ``(face_restore, None)`` when the ``face_restore`` folder exists;
            * ``(face_restore, upscale_models)`` when only the
              ``upscale_models`` folder exists (the first entry then points
              at a folder that does not exist);
            * ``(None, None)`` when neither folder exists.
        """
        fr = get_model_path("face_restore")
        if fr.exists():
            return (fr, None)

        um = get_model_path("upscale_models")
        return (fr, um) if um.exists() else (None, None)

    @staticmethod
    def _list_checkpoints(folder):
        """Return the GFPGAN / RestoreFormer ``.pth`` files directly in *folder*."""
        return [
            x
            for x in folder.iterdir()
            if x.name.endswith(".pth")
            and ("GFPGAN" in x.name or "RestoreFormer" in x.name)
        ]

    @classmethod
    def get_models(cls):
        """List the available face restoration checkpoints.

        The ``face_restore`` folder is preferred; ``upscale_models`` is only
        scanned as a fallback when ``face_restore`` does not exist.

        Returns:
            A list of checkpoint paths, empty when nothing usable is found.
            A warning is logged (once per class) when no model folder exists.
        """
        fr_models_path, um_models_path = cls.get_models_root()

        if fr_models_path is None and um_models_path is None:
            # `_warned` keeps the message from being repeated on every call.
            if not hasattr(cls, "_warned"):
                log.warning("Face restoration models not found.")
                cls._warned = True
            return []

        if fr_models_path is not None and fr_models_path.exists():
            return cls._list_checkpoints(fr_models_path)

        # Fallback: checkpoints stored alongside the upscale models.
        if um_models_path is not None and um_models_path.exists():
            return cls._list_checkpoints(um_models_path)

        return []

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs.

        Declared as a classmethod because it is invoked on the class itself
        and enumerates the checkpoints through ``cls.get_models()``.
        """
        return {
            "required": {
                "model_name": (
                    [x.name for x in cls.get_models()],
                    {"default": "None"},
                ),
                "upscale": ("INT", {"default": 1}),
            },
            "optional": {"bg_upsampler": ("UPSCALE_MODEL", {"default": None})},
        }

    def load_model(self, model_name, upscale=2, bg_upsampler=None):
        """Instantiate a ``GFPGANer`` for the selected checkpoint.

        Args:
            model_name: File name of the checkpoint, resolved against the
                ``face_restore`` folder or, when that folder is missing,
                against ``upscale_models``.
            upscale: Upscale factor handed to ``GFPGANer``. Replaced by
                ``bg_upsampler.scale`` when a background upsampler is given.
            bg_upsampler: Optional upscale model; it is wrapped in
                ``BGUpscaleWrapper`` before being handed to ``GFPGANer``.

        Returns:
            A one-element tuple holding the ``GFPGANer`` instance.

        Raises:
            FileNotFoundError: When no model folder exists at all.
        """
        from gfpgan import GFPGANer

        # Every checkpoint that is not a RestoreFormer uses the "clean" arch.
        basic = "RestoreFormer" not in model_name

        fr_root, um_root = self.get_models_root()
        if fr_root is not None and fr_root.exists():
            models_root = fr_root
        else:
            models_root = um_root
        if models_root is None:
            raise FileNotFoundError(
                "No face restoration model folder found "
                f"(cannot load '{model_name}')."
            )

        if bg_upsampler is not None:
            log.warning(
                f"Upscale value overridden to {bg_upsampler.scale} from bg_upsampler"
            )
            upscale = bg_upsampler.scale
            bg_upsampler = BGUpscaleWrapper(bg_upsampler)

        # Silence stdout while the model is constructed. The previously active
        # stream (not sys.__stdout__) is restored, even on failure, so any
        # redirection installed by the host application is left intact.
        previous_stdout = sys.stdout
        sys.stdout = NullWriter()
        try:
            model = GFPGANer(
                model_path=(models_root / model_name).as_posix(),
                upscale=upscale,
                arch="clean"
                if basic
                else "RestoreFormer",  # or original for v1.0 only
                channel_multiplier=2,  # 1 for v1.0 only
                bg_upsampler=bg_upsampler,
            )
        finally:
            sys.stdout = previous_stdout

        return (model,)
class BGUpscaleWrapper:
    """Adapter exposing an ``enhance`` method around an upscale model.

    ``MTB_LoadFaceEnhanceModel.load_model`` hands instances of this class to
    ``GFPGANer`` as its ``bg_upsampler``.
    """

    def __init__(self, upscale_model) -> None:
        self.upscale_model = upscale_model

    def enhance(self, img: Image.Image, outscale=2):
        """Upscale *img* tile by tile with the wrapped model.

        Args:
            img: Image to upscale; converted with ``np2tensor``.
            outscale: Not used here; the scale factor is taken from
                ``self.upscale_model.scale``.

        Returns:
            A one-element tuple holding the first image produced by
            ``tensor2np`` on the clamped result.
        """
        target_device = model_management.get_torch_device()
        self.upscale_model.to(target_device)

        tile_size = 192  # 128 + 64
        tile_overlap = 8

        # Move the last axis to third-from-last before scaling
        # (presumably channels-last -> channels-first; confirm with np2tensor).
        batch = np2tensor(img)
        batch = batch.movedim(-1, -3).to(target_device)

        steps_per_image = comfy.utils.get_tiled_scale_steps(
            batch.shape[3],
            batch.shape[2],
            tile_x=tile_size,
            tile_y=tile_size,
            overlap=tile_overlap,
        )
        total_steps = batch.shape[0] * steps_per_image
        log.debug(f"Steps: {total_steps}")
        progress = comfy.utils.ProgressBar(total_steps)

        def run_model(tile_batch):
            return self.upscale_model(tile_batch)

        upscaled = comfy.utils.tiled_scale(
            batch,
            run_model,
            tile_x=tile_size,
            tile_y=tile_size,
            overlap=tile_overlap,
            upscale_amount=self.upscale_model.scale,
            pbar=progress,
        )

        # Move the model back to the CPU once the scaling is done.
        self.upscale_model.cpu()

        clamped = torch.clamp(upscaled.movedim(-3, -1), min=0, max=1.0)
        first_image = tensor2np(clamped)[0]
        return (first_image,)
| import sys | |
class MTB_RestoreFace:
    """Uses GFPGan to restore faces"""

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "restore"
    CATEGORY = "mtb/facetools"
    DEPRECATED = True

    def __init__(self) -> None:
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (declared as a classmethod so it can be
        called on the class itself)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "model": ("FACEENHANCE_MODEL",),
                # Input are aligned faces
                "aligned": ("BOOLEAN", {"default": False}),
                # Only restore the center face
                "only_center_face": ("BOOLEAN", {"default": False}),
                # Adjustable weights
                "weight": ("FLOAT", {"default": 0.5}),
                "save_tmp_steps": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "preserve_alpha": ("BOOLEAN", {"default": True}),
            },
        }

    def do_restore(
        self,
        image: torch.Tensor,
        model,
        aligned,
        only_center_face,
        weight,
        save_tmp_steps,
        preserve_alpha: bool = False,
    ) -> torch.Tensor:
        """Restore the faces of a single image.

        Args:
            image: One image of the batch (``restore`` passes ``image[i]``);
                its last dimension is treated as the channel axis.
            model: Object exposing ``enhance`` (the ``GFPGANer`` built by the
                loader node).
            aligned: Forwarded as ``has_aligned``.
            only_center_face: Forwarded to ``model.enhance``.
            weight: Forwarded to ``model.enhance``.
            save_tmp_steps: When true, the cropped / restored faces are
                written to the temp directory.
            preserve_alpha: When true and the input has four channels, the
                alpha channel is resized and re-attached to the result.

        Returns:
            The result converted with ``pil2tensor``. When ``model.enhance``
            yields no pasted-back image, the first restored face is returned,
            or the unmodified input when there is none.
        """
        pimage = tensor2np(image)[0]
        height, width = pimage.shape[0], pimage.shape[1]

        # Split the alpha channel off *before* the colour conversion so that
        # only the three colour channels reach the model.
        has_alpha = image.size(-1) == 4
        alpha_channel = None
        if preserve_alpha and has_alpha:
            alpha_channel = pimage[:, :, 3]
        rgb = np.ascontiguousarray(pimage[:, :, :3] if has_alpha else pimage)
        source_img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

        # Silence stdout during inference. The previously active stream (not
        # sys.__stdout__) is restored, even on failure, so any redirection
        # installed by the host application is left intact.
        previous_stdout = sys.stdout
        sys.stdout = NullWriter()
        try:
            cropped_faces, restored_faces, restored_img = model.enhance(
                source_img,
                has_aligned=aligned,
                only_center_face=only_center_face,
                paste_back=True,
                # TODO: weight has no effect in 1.3 and 1.4 (only tested these for now...)
                weight=weight,
            )
        finally:
            sys.stdout = previous_stdout

        log.warning(f"Weight value has no effect for now. (value: {weight})")

        if save_tmp_steps:
            self.save_intermediate_images(
                cropped_faces, restored_faces, height, width
            )

        if restored_img is not None:
            output = Image.fromarray(
                cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
            )
        elif restored_faces:
            # No pasted-back image (e.g. aligned input): the restored face
            # itself is the result.
            log.warning(
                "No pasted-back image returned, using the first restored face."
            )
            output = Image.fromarray(
                cv2.cvtColor(restored_faces[0], cv2.COLOR_BGR2RGB)
            )
        else:
            log.warning("No face restored, returning the input image.")
            output = Image.fromarray(rgb)

        if alpha_channel is not None:
            alpha_resized = Image.fromarray(alpha_channel).resize(
                output.size, Image.LANCZOS
            )
            output.putalpha(alpha_resized)

        return pil2tensor(output)

    def restore(
        self,
        image: torch.Tensor,
        model,
        aligned=False,
        only_center_face=False,
        weight=0.5,
        save_tmp_steps=True,
        preserve_alpha: bool = False,
    ) -> tuple[torch.Tensor]:
        """Restore every image of the batch and concatenate the results.

        Iterates over the first dimension of *image*, runs ``do_restore`` on
        each entry and joins the outputs along dimension 0.

        Returns:
            A one-element tuple holding the concatenated tensor.
        """
        out = [
            self.do_restore(
                frame,
                model,
                aligned,
                only_center_face,
                weight,
                save_tmp_steps,
                preserve_alpha,
            )
            for frame in image
        ]

        return (torch.cat(out, dim=0),)

    def get_step_image_path(self, step, idx):
        """Build the temp-directory PNG path for an intermediate image.

        Args:
            step: Name of the processing step, used as file name prefix.
            idx: Face index, zero-padded to three digits in the prefix.
        """
        (
            full_output_folder,
            filename,
            counter,
            _subfolder,
            _filename_prefix,
        ) = folder_paths.get_save_image_path(
            f"{step}_{idx:03}",
            folder_paths.temp_directory,
        )
        file = f"{filename}_{counter:05}_.png"

        return os.path.join(full_output_folder, file)

    def save_intermediate_images(
        self, cropped_faces, restored_faces, height, width
    ):
        """Write each cropped face, its restored version and a side-by-side
        comparison to the temp directory.

        ``height`` and ``width`` are accepted but not used.
        """
        for idx, (cropped_face, restored_face) in enumerate(
            zip(cropped_faces, restored_faces, strict=False)
        ):
            face_id = idx + 1

            file = self.get_step_image_path("cropped_faces", face_id)
            cv2.imwrite(file, cropped_face)

            file = self.get_step_image_path("cropped_faces_restored", face_id)
            cv2.imwrite(file, restored_face)

            file = self.get_step_image_path("cropped_faces_compare", face_id)
            # save comparison image
            cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
            cv2.imwrite(file, cmp_img)
# Node classes exposed by this module
# (presumably collected by the package's node loader; confirm against the loader).
__nodes__ = [
    MTB_RestoreFace,
    MTB_LoadFaceEnhanceModel,
]