from typing import TypedDict

import diffusers.image_processor
import gradio as gr
import numpy as np
import pillow_heif
import spaces
import torch
from PIL import Image
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor

from pipeline import TryOffAnyone

# Register HEIF/AVIF decoders so PIL can open uploads in those formats.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
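# "high" float32 matmul precision plus allow_tf32 let matrix multiplies use TF32
# tensor cores on Ampere-or-newer GPUs, trading a little precision for speed.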
TITLE = """
# Try Off Anyone

## Important
1. Choose an example image or upload your own.
2. Click **Extract Clothing** to generate the TryOff result.

[[arxiv:2412.08573]](https://arxiv.org/abs/2412.08573)
[[github:ixarchakos/try-off-anyone]](https://github.com/ixarchakos/try-off-anyone)
"""
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float32

pipeline_tryoff = TryOffAnyone(
    device=DEVICE,
    dtype=DTYPE,
)
mask_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
vae_processor = diffusers.image_processor.VaeImageProcessor(
    vae_scale_factor=8,
)
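
# Segment the person with Segformer and keep only the classes belonging to the
# requested garment category, returning a binary (0/255) PIL mask.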
def mask_generation(image, processor, model, category):
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits.cpu()

    # Upsample the low-resolution logits back to the input image size.
    upsampled_logits = torch.nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    predicted_mask = upsampled_logits.argmax(dim=1).squeeze().numpy()

    # Select the label ids used for each category (4/7 for tops, 5/6 for bottoms
    # in this segmentation model's label map).
    if category == "Tops":
        predicted_mask_1 = predicted_mask == 4
        predicted_mask_2 = predicted_mask == 7
    elif category == "Bottoms":
        predicted_mask_1 = predicted_mask == 5
        predicted_mask_2 = predicted_mask == 6
    else:
        raise NotImplementedError(f"Unsupported category: {category}")

    predicted_mask = predicted_mask_1 | predicted_mask_2
    mask_image = Image.fromarray((predicted_mask * 255).astype(np.uint8))
    return mask_image
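
# Value delivered by gr.ImageMask with type="pil": the uploaded image
# ("background"), any hand-drawn layers, and their composite.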
class ImageData(TypedDict):
    background: Image.Image
    composite: Image.Image
    layers: list[Image.Image]
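
# Gradio callback: resize the upload, segment the top garment to build the mask,
# and run the TryOffAnyone pipeline to produce the extracted-garment image.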
@spaces.GPU  # request ZeroGPU hardware for this call; assumed intended, since `spaces` is imported but otherwise unused
def process(
    image_data: ImageData,
    image_width: int,
    image_height: int,
    num_inference_steps: int,
    condition_scale: float,
    seed: int,
) -> Image.Image:
    assert image_width > 0
    assert image_height > 0
    assert num_inference_steps > 0
    assert condition_scale > 0
    assert seed >= 0

    # Use the uploaded image; the garment mask is generated below with Segformer
    # rather than taken from the hand-drawn layers.
    image = image_data["background"]

    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer_b3_clothes")
    model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer_b3_clothes")
    model.to(DEVICE)

    # Preprocess the image.
    image = image.convert("RGB").resize((image_width, image_height))
    mask = mask_generation(image, processor, model, "Tops")
    image_preprocessed = vae_processor.preprocess(
        image=image,
        width=image_width,
        height=image_height,
    )[0]

    # Preprocess the mask.
    mask = mask.resize((image_width, image_height))
    mask_preprocessed = mask_processor.preprocess(  # pyright: ignore[reportUnknownMemberType]
        image=mask,
        width=image_width,
        height=image_height,
    )[0]

    # Generate the TryOff image.
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    tryoff_image = pipeline_tryoff(
        image_preprocessed,
        mask_preprocessed,
        inference_steps=num_inference_steps,
        scale=condition_scale,
        generator=generator,
    )[0]
    return tryoff_image
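
# Gradio UI: ImageMask input and examples on the left, the TryOff result on the
# right, and fixed generation settings in an "Advanced Settings" accordion.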
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            input_image = gr.ImageMask(
                label="Input Image",
                height=1024,
                type="pil",
                interactive=True,
            )
            run_button = gr.Button(
                value="Extract Clothing",
            )
            gr.Examples(
                examples=[
                    ["examples/model_1.jpg"],
                    ["examples/model_2.jpg"],
                    ["examples/model_3.jpg"],
                    ["examples/model_4.jpg"],
                    ["examples/model_5.jpg"],
                    ["examples/model_6.jpg"],
                    ["examples/model_7.jpg"],
                    ["examples/model_8.jpg"],
                    ["examples/model_9.jpg"],
                ],
                inputs=[input_image],
            )
        with gr.Column():
            output_image = gr.Image(
                label="TryOff result",
                height=1024,
                image_mode="RGB",
                type="pil",
            )
    with gr.Accordion("Advanced Settings", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=36,
            maximum=36,
            value=36,
            step=1,
        )
        scale = gr.Slider(
            label="Scale",
            minimum=2.5,
            maximum=2.5,
            value=2.5,
            step=0,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=50,
            maximum=50,
            value=50,
            step=1,
        )
        with gr.Row():
            image_width = gr.Slider(
                label="Image Width",
                minimum=384,
                maximum=384,
                value=384,
                step=8,
            )
            image_height = gr.Slider(
                label="Image Height",
                minimum=512,
                maximum=512,
                value=512,
                step=8,
            )
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_width,
            image_height,
            num_inference_steps,
            scale,
            seed,
        ],
        outputs=output_image,
    )

demo.launch()