import sys

sys.path.append("../")  # make the repo-root packages (utils, flowgen, flowdiffusion) importable

import os
import re
import time
import datetime
from copy import deepcopy

import numpy as np
import cv2
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file

from utils.flow_utils import flow_to_image, resize_flow
from flowgen.models import UnetGenerator
from flowdiffusion.pipeline import FlowDiffusionPipeline

LENGTH = 512
FLOWGAN_RESOLUTION = [256, 256]  # HxW
FLOWDIFFUSION_RESOLUTION = [512, 512]  # HxW
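
# The editing canvas is LENGTH x LENGTH pixels. FlowGAN predicts flow at
# FLOWGAN_RESOLUTION and FlowDiffusion edits at FLOWDIFFUSION_RESOLUTION,
# so tensors are resized between stages inside InstantDragPipeline.drag.
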
def process_img(image):
    if image["composite"] is not None and not np.all(image["composite"] == 0):
        original_image = Image.fromarray(image["composite"]).resize((LENGTH, LENGTH), Image.BICUBIC)
        original_image = np.array(exif_transpose(original_image))
        return original_image, [], gr.Image(value=deepcopy(original_image), interactive=False)
    else:
        return (
            gr.Image(value=None, interactive=False),
            [],
            gr.Image(value=None, interactive=False),
        )
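
# Click handler for the editable image: clicks accumulate in sel_pix as (x, y)
# pairs. Even-indexed clicks are drag sources (red dots), odd-indexed clicks
# are drag targets (blue dots), and each completed pair gets a white arrow.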
def get_points(img, sel_pix, evt: gr.SelectData):
    sel_pix.append(evt.index)
    print(sel_pix)
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)
        else:
            cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
            points = []
    img = img if isinstance(img, np.ndarray) else np.array(img)
    return img
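
# Renders drag instructions typed into a textbox instead of clicked: the text
# is split on whitespace, non-digit characters are stripped, and the integers
# are consumed pairwise as x y coordinates (source, then target, alternating).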
def display_points(img, predefined_points, save_results):
    selected_points = []  # default so the drawing loop below is safe on empty input
    if predefined_points != "":
        predefined_points = predefined_points.split()
        predefined_points = [int(re.sub(r'[^0-9]', '', point)) for point in predefined_points]
        processed_points = []
        for i, point in enumerate(predefined_points):
            if i % 2 == 0:
                processed_points.append([point, predefined_points[i + 1]])
        selected_points = processed_points
    print(selected_points)
    points = []
    for idx, point in enumerate(selected_points):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 4, (255, 0, 0), -1)
        else:
            cv2.circle(img, tuple(point), 4, (0, 0, 255), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 2, tipLength=0.5)
            points = []
    img = img if isinstance(img, np.ndarray) else np.array(img)
    if save_results:
        if not os.path.isdir("results/drag_inst_viz"):
            os.makedirs("results/drag_inst_viz")
        save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
        to_save_img = Image.fromarray(img)
        to_save_img.save(f"results/drag_inst_viz/{save_prefix}.png")
    return img

def undo_points_image(original_image):
    if original_image is not None:
        return original_image, []
    else:
        return gr.Image(value=None, interactive=False), []

def clear_all():
    return (
        gr.Image(value=None, interactive=True),
        gr.Image(value=None, interactive=False),
        gr.Image(value=None, interactive=False),
        [],
        None,
    )
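
# Two-stage pipeline: FlowGAN (a UnetGenerator taking the RGB image plus a
# 2-channel sparse drag map, 5 input channels total) densifies the drag into
# optical flow, and FlowDiffusionPipeline synthesizes the edit from that flow.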
class InstantDragPipeline:
    def __init__(self, seed=9999, device="cuda", dtype=torch.float16):
        self.seed = seed
        self.device = device
        self.dtype = dtype
        self.generator = torch.Generator(device=device).manual_seed(seed)
        self.flowgen_ckpt, self.flowdiffusion_ckpt = None, None
        self.model_config = dict()

    def build_model(self):
        print("Building model...")
        # Reload each model only when the requested checkpoint has changed
        if self.flowgen_ckpt != self.model_config["flowgen_ckpt"]:
            self.flowgen = UnetGenerator(input_nc=5, output_nc=2)
            self.flowgen.load_state_dict(
                load_file(os.path.join("checkpoints/", self.model_config["flowgen_ckpt"]), device="cpu")
            )
            self.flowgen.to(self.device)
            self.flowgen.eval()
            self.flowgen_ckpt = self.model_config["flowgen_ckpt"]
        if self.flowdiffusion_ckpt != self.model_config["flowdiffusion_ckpt"]:
            self.flowdiffusion = FlowDiffusionPipeline.from_pretrained(
                os.path.join("checkpoints/", self.model_config["flowdiffusion_ckpt"]),
                torch_dtype=self.dtype,
                safety_checker=None,
            )
            self.flowdiffusion.to(self.device)
            self.flowdiffusion_ckpt = self.model_config["flowdiffusion_ckpt"]
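
    # Core editing step: rasterize the drag points into a sparse flow map,
    # densify it with FlowGAN, then condition FlowDiffusion on that flow.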
    def drag(self, original_image, selected_points, save_results):
        scale = self.model_config["flowgen_output_scale"]
        original_image = torch.tensor(original_image).permute(2, 0, 1).unsqueeze(0).float()  # 1, 3, 512, 512
        original_image = 2 * (original_image / 255.) - 1  # Normalize to [-1, 1]
        original_image = original_image.to(self.device)
        source_points = []
        target_points = []
        for idx, point in enumerate(selected_points):
            cur_point = torch.tensor([point[0], point[1]])  # x, y
            if idx % 2 == 0:
                source_points.append(cur_point)
            else:
                target_points.append(cur_point)
        torch.cuda.synchronize()
        start_time = time.time()
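
        # The sparse map stores, at each source pixel, its (dx, dy) displacement
        # toward the matching target point; all other pixels stay zero.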
        # Generate sparse flow vectors
        point_vector_map = torch.zeros((1, 2, LENGTH, LENGTH))
        for source_point, target_point in zip(source_points, target_points):
            cur_x, cur_y = source_point[0], source_point[1]
            target_x, target_y = target_point[0], target_point[1]
            vec_x = target_x - cur_x
            vec_y = target_y - cur_y
            point_vector_map[0, 0, int(cur_y), int(cur_x)] = vec_x
            point_vector_map[0, 1, int(cur_y), int(cur_x)] = vec_y
        point_vector_map = point_vector_map.to(self.device)

        # Sample-wise normalize the flow vectors
        factor_x = torch.amax(torch.abs(point_vector_map[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
        factor_y = torch.amax(torch.abs(point_vector_map[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1).to(self.device)
        if factor_x >= 1e-8:  # Avoid division by zero
            point_vector_map[:, 0, :, :] /= factor_x
        if factor_y >= 1e-8:  # Avoid division by zero
            point_vector_map[:, 1, :, :] /= factor_y
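
        # The per-axis factors computed above are reapplied to FlowGAN's output
        # below, restoring real pixel magnitudes from the normalized prediction.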
        with torch.inference_mode():
            gan_input_image = F.interpolate(original_image, size=FLOWGAN_RESOLUTION, mode="bicubic")  # 256 x 256
            point_vector_map = F.interpolate(point_vector_map, size=FLOWGAN_RESOLUTION, mode="bicubic")  # 256 x 256
            gan_input = torch.cat([gan_input_image, point_vector_map], dim=1)
            flow = self.flowgen(gan_input)  # -1 ~ 1
            if scale == -1.0:
                # Force each axis to span the full [-1, 1] range
                flow[:, 0, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 0, :, :]), dim=(1, 2)).view(-1, 1, 1)
                flow[:, 1, :, :] *= 1.0 / torch.amax(torch.abs(flow[:, 1, :, :]), dim=(1, 2)).view(-1, 1, 1)
            else:
                # Manually adjust the scale
                flow[:, 0, :, :] *= scale
                flow[:, 1, :, :] *= scale
            if factor_x >= 1e-8:
                flow[:, 0, :, :] *= factor_x * (FLOWGAN_RESOLUTION[1] / original_image.shape[3])  # width
            else:
                flow[:, 0, :, :] *= 0
            if factor_y >= 1e-8:
                flow[:, 1, :, :] *= factor_y * (FLOWGAN_RESOLUTION[0] / original_image.shape[2])  # height
            else:
                flow[:, 1, :, :] *= 0
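
            # Resize the flow to 1/8 of the diffusion resolution; this is
            # presumably the latent-space size used by the diffusion U-Net.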
            resized_flow = resize_flow(
                flow,
                (FLOWDIFFUSION_RESOLUTION[0] // 8, FLOWDIFFUSION_RESOLUTION[1] // 8),
                scale_type="normalize_fixed",
            )
            kwargs = {
                "image": original_image.to(self.dtype),
                "flow": resized_flow.to(self.dtype),
                "num_inference_steps": self.model_config["n_inference_step"],
                "image_guidance_scale": self.model_config["image_guidance"],
                "flow_guidance_scale": self.model_config["flow_guidance"],
                "generator": self.generator,
            }
            edited_image = self.flowdiffusion(**kwargs).images[0]
        end_time = time.time()
        inference_time = end_time - start_time
        print(f"Inference Time: {inference_time} seconds")
        if save_results:
            save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
            if not os.path.isdir("results/flows"):
                os.makedirs("results/flows")
            np.save(f"results/flows/{save_prefix}.npy", flow[0].detach().cpu().numpy())
            if not os.path.isdir("results/flow_visualized"):
                os.makedirs("results/flow_visualized")
            flow_to_image(flow[0].detach()).save(f"results/flow_visualized/{save_prefix}.png")
            if not os.path.isdir("results/edited_images"):
                os.makedirs("results/edited_images")
            edited_image.save(f"results/edited_images/{save_prefix}.png")
            if not os.path.isdir("results/drag_instructions"):
                os.makedirs("results/drag_instructions")
            with open(f"results/drag_instructions/{save_prefix}.txt", "w") as f:
                f.write(str(selected_points))
        edited_image = np.array(edited_image)
        return edited_image

    def run(self, original_image, selected_points,
            flowgen_ckpt, flowdiffusion_ckpt, image_guidance, flow_guidance, flowgen_output_scale,
            num_steps, save_results):
        self.model_config = {
            "flowgen_ckpt": flowgen_ckpt,
            "flowdiffusion_ckpt": flowdiffusion_ckpt,
            "image_guidance": image_guidance,
            "flow_guidance": flow_guidance,
            "flowgen_output_scale": flowgen_output_scale,
            "n_inference_step": num_steps,
        }
        self.build_model()
        edited_image = self.drag(original_image, selected_points, save_results)
        return edited_image
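
# Minimal usage sketch (not part of the demo UI). The checkpoint names and
# guidance values below are illustrative placeholders, not shipped defaults;
# substitute the actual files under checkpoints/:
#
#   pipeline = InstantDragPipeline()
#   image = np.array(Image.open("example.png").convert("RGB").resize((LENGTH, LENGTH)))
#   points = [[100, 150], [180, 150]]  # one drag: source (x, y), then target (x, y)
#   edited = pipeline.run(
#       image, points,
#       flowgen_ckpt="<flowgen>.safetensors",        # placeholder filename
#       flowdiffusion_ckpt="<flowdiffusion_dir>",    # placeholder directory
#       image_guidance=1.5, flow_guidance=1.5,       # illustrative values
#       flowgen_output_scale=-1.0,                   # -1.0 = auto-normalize
#       num_steps=20, save_results=False,
#   )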