import numpy as np
import gradio as gr
import spaces
import os
import random
import subprocess
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
import warnings
from typing import Tuple
| """ | |
| FLUX‑1 ControlNet demo | |
| ---------------------- | |
| This script rebuilds the Gradio interface shown in your screenshot with **one** control‑image upload | |
| slot and integrates the FLUX.1‑dev‑ControlNet‑Union‑Pro model. | |
| Key points | |
| ~~~~~~~~~~ | |
| * Single *control image* input (left). | |
| * *Result* and *Pre‑processed Cond* previews side‑by‑side (center & right). | |
| * *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the mode and strength. | |
| * Seed handling with optional randomisation. | |
| * Advanced sliders for *Guidance scale* and *Inference steps*. | |
| * Works on CUDA (bfloat16) or CPU (float32). | |
| * Minimal Canny preview implementation when the *canny* mode is selected (extend as you like for the | |
| other modes). | |
| Before running, set the `HUGGINGFACE_TOKEN` environment variable **or** call | |
| `login("<YOUR_HF_TOKEN>")` explicitly. | |
| """ | |
# Clear any stale ZeroGPU offload cache left over from a previous run.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN_NEW")
if HF_TOKEN:  # avoid a hard failure when the variable is unset
    login(HF_TOKEN)
# If you prefer to hard-code the token, uncomment:
# login("hf_your_token_here")

BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

print("Loading ControlNet weights ...")
controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
# Wrap the Union checkpoint in a multi-ControlNet container, since the
# pipeline call below passes lists of images / modes / scales.
controlnet = FluxMultiControlNetModel([controlnet_single])
print("Loading FLUX.1-dev pipeline ...")
pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)
print("Pipeline ready.")
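
# Optional alternative (assumption, not used here): on GPUs with limited VRAM,
# the `.to(device)` above could be swapped for diffusers' model CPU offload:
#   pipe.enable_model_cpu_offload()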

# --------------------------------------------------
# UI -> model value mapping
# --------------------------------------------------
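# The integer values below are the `control_mode` indices that the Union-Pro
# checkpoint expects; the ordering follows the model card on the Hugging Face Hub.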
MODE_MAPPING = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low quality": 6,
}
MAX_SEED = 100
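# Note: 100 is an unusually small seed ceiling; raising it (e.g. towards
# np.iinfo(np.int32).max) would give "Randomize seed" far more variety.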

# -----------------------------------------------------------------------------
# Preview helpers - one small, self-contained function per mode
# -----------------------------------------------------------------------------
def _preview_canny(
    pil_img: Image.Image, canny_threshold_1: int, canny_threshold_2: int
) -> Image.Image:
    """Fast Canny-edge preview."""
    arr = np.array(pil_img.convert("RGB"))
    blurred = cv2.GaussianBlur(arr, (5, 5), 1.4)
    edges = cv2.Canny(blurred, threshold1=canny_threshold_1, threshold2=canny_threshold_2)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)

# ――― tile ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_tile(pil_img: Image.Image, grid: Tuple[int, int] = (2, 2)) -> Image.Image:
    """Replicate *pil_img* into an *n x m* tiled grid (default 2x2).

    This offers a quick visual hint of what a *tiling* control mode will do
    (repeatable textures, etc.)."""
    cols, rows = grid
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    tiled = Image.new("RGB", (w * cols, h * rows))
    for c in range(cols):
        for r in range(rows):
            tiled.paste(img_rgb, (c * w, r * h))
    return tiled

# ――― depth ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_depth(pil_img: Image.Image) -> Image.Image:
    """Very rough *depth* proxy using the Laplacian and a colormap.

    * Convert to gray
    * Run the Laplacian to highlight depth-like gradients
    * Apply a TURBO colormap to mimic a depth heat-map"""
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    lap = cv2.Laplacian(gray, cv2.CV_16S, ksize=3)
    depth = cv2.convertScaleAbs(lap)
    depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)
    # applyColorMap returns BGR; swap channels so PIL renders it correctly
    depth_color = cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB)
    return Image.fromarray(depth_color)
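
# A faithful depth map would come from a monocular depth model rather than the
# Laplacian proxy above - a minimal sketch, assuming `transformers` is
# installed and you accept its default depth-estimation checkpoint:
#   from transformers import pipeline as hf_pipeline
#   depth_estimator = hf_pipeline("depth-estimation")
#   depth_pil = depth_estimator(pil_img)["depth"]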

# ――― blur ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_blur(pil_img: Image.Image, ksize: int = 15) -> Image.Image:
    """Gaussian-blur preview.

    A single, relatively large kernel is enough for UI illustration."""
    if ksize % 2 == 0:
        ksize += 1  # Gaussian kernel size must be odd
    blurred = cv2.GaussianBlur(np.array(pil_img.convert("RGB")), (ksize, ksize), sigmaX=0)
    return Image.fromarray(blurred)

# ――― pose ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_pose(pil_img: Image.Image) -> Image.Image:
    """Attempt a lightweight 2-D pose overlay using *mediapipe* if available.

    If *mediapipe* is not installed (or CPU inference fails), we gracefully
    fall back to an edge-map preview so the UI never crashes."""
    try:
        import mediapipe as mp  # type: ignore

        mp_pose = mp.solutions.pose
        mp_drawing = mp.solutions.drawing_utils
        img_rgb = np.array(pil_img.convert("RGB"))  # Mediapipe expects RGB
        with mp_pose.Pose(static_image_mode=True) as pose_estimator:
            results = pose_estimator.process(img_rgb)
        annotated = img_rgb.copy()
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
            )
        return Image.fromarray(annotated)
    except Exception as exc:  # pragma: no cover - any import / runtime error
        warnings.warn(
            f"Pose preview failed ({exc!s}); falling back to Canny.", RuntimeWarning
        )
        # Return an edge map as a sensible fallback rather than exploding the UI
        return _preview_canny(pil_img, 100, 200)
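
# A production-grade pose map could come from a dedicated preprocessor
# instead - a sketch, assuming the `controlnet_aux` package is installed:
#   from controlnet_aux import OpenposeDetector
#   openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
#   pose_map = openpose(pil_img)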

# ――― gray ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_gray(pil_img: Image.Image) -> Image.Image:
    """Simple grayscale conversion, kept as a 3-channel RGB image so the UI
    widget pipeline stays consistent."""
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(gray_rgb)

# ――― low quality ――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_low_quality(pil_img: Image.Image, factor: int = 8) -> Image.Image:
    """Mimic a low-quality thumbnail: aggressively downsample, then upscale.

    The default *factor* (8x) is chosen to make artefacts obvious."""
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    small = img_rgb.resize((max(1, w // factor), max(1, h // factor)), Image.BILINEAR)
    # Upsample with NEAREST to exaggerate the blocky artefacts
    low_q = small.resize((w, h), Image.NEAREST)
    return low_q

# -----------------------------------------------------------------------------
# Master dispatch
# -----------------------------------------------------------------------------
def _make_preview(
    control_image: Image.Image,
    mode: str,
    canny_threshold_1: int = 100,
    canny_threshold_2: int = 200,
) -> Image.Image:
    """Return a *quick-n-dirty* preview image for the requested *mode*.

    Parameters
    ----------
    control_image : PIL.Image
        The input image selected by the user.
    mode : str
        One of the keys of :data:`MODE_MAPPING`.
    canny_threshold_1 / 2 : int, optional
        Only used if *mode* is "canny" (passed straight to OpenCV Canny).
    """
    mode = mode.lower()
    if mode not in MODE_MAPPING:
        warnings.warn(f"Unknown preview mode '{mode}'. Returning untouched image.")
        return control_image
    if mode == "canny":
        return _preview_canny(control_image, canny_threshold_1, canny_threshold_2)
    if mode == "tile":
        return _preview_tile(control_image)
    if mode == "depth":
        return _preview_depth(control_image)
    if mode == "blur":
        return _preview_blur(control_image)
    if mode == "pose":
        return _preview_pose(control_image)
    if mode == "gray":
        return _preview_gray(control_image)
    if mode == "low quality":
        return _preview_low_quality(control_image)
    # Fallback - should never happen thanks to the early mode check
    return control_image

# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of each call
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    canny_threshold_1: int,
    canny_threshold_2: int,
):
    if control_image is None:
        raise gr.Error("Please upload a control image first.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    gen = torch.Generator(device).manual_seed(seed)
    # FLUX expects dimensions divisible by 16; round down to stay valid.
    w, h = control_image.size
    w, h = w // 16 * 16, h // 16 * 16
    preprocessed = _make_preview(
        control_image, mode, canny_threshold_1, canny_threshold_2
    )
    result = pipe(
        prompt=prompt,
        control_image=[preprocessed],
        control_mode=[MODE_MAPPING[mode]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]
    return result, seed, preprocessed
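
# Programmatic usage outside the UI - a sketch (assumes a local `cond.png`):
#   img = Image.open("cond.png")
#   result, used_seed, cond_preview = infer(
#       img, "White background", "canny", 0.76, 42, False, 3.5, 50, 100, 200
#   )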

# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
css = """#wrapper {max-width: 960px; margin: 0 auto;}"""

with gr.Blocks(css=css, elem_id="wrapper") as demo:
    gr.Markdown("## FLUX.1-dev-ControlNet-Union-Pro by Frank")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1-dev** from the InstantX team and Shakker Labs. "
        + "Recommended strengths: *canny 0.76*. Long prompts usually help."
    )
    # ------------ Image panel row ------------
    with gr.Row():
        control_image = gr.Image(
            label="Upload an image",
            type="pil",
            height=512 + 256,
        )
        result_image = gr.Image(label="Result", height=512 + 256)
        preview_image = gr.Image(label="Pre-processed Cond", height=512 + 256)

    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="White background", lines=1)

    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="canny", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.76, step=0.01, label="Control strength"
            )
            gr.Markdown("### Preprocess")
            canny_threshold_1 = gr.Slider(
                0, 500, step=1, value=100, label="Canny threshold 1"
            )
            canny_threshold_2 = gr.Slider(
                0, 500, step=1, value=200, label="Canny threshold 2"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=False)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(1, 50, step=1, value=50, label="Inference steps")

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
            canny_threshold_1,
            canny_threshold_2,
        ],
        outputs=[result_image, seed_slider, preview_image],
    )
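
    # Optional (assumption): refresh the conditioning preview live when the
    # mode changes, without running the full diffusion pipeline:
    #   mode_radio.change(
    #       fn=_make_preview,
    #       inputs=[control_image, mode_radio, canny_threshold_1, canny_threshold_2],
    #       outputs=preview_image,
    #   )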

if __name__ == "__main__":
    demo.launch()