Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Create app.py
Browse files
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,96 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import gradio as gr
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
            from diffusers.utils import load_image
         
     | 
| 4 | 
         
            +
            from controlnet_flux import FluxControlNetModel
         
     | 
| 5 | 
         
            +
            from transformer_flux import FluxTransformer2DModel
         
     | 
| 6 | 
         
            +
            from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
         
     | 
| 7 | 
         
            +
            from PIL import Image, ImageDraw
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            # Load models
         
     | 
| 10 | 
         
            +
            controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
         
     | 
| 11 | 
         
            +
            transformer = FluxTransformer2DModel.from_pretrained(
         
     | 
| 12 | 
         
            +
                "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
         
     | 
| 13 | 
         
            +
            )
         
     | 
| 14 | 
         
            +
            pipe = FluxControlNetInpaintingPipeline.from_pretrained(
         
     | 
| 15 | 
         
            +
                "black-forest-labs/FLUX.1-dev",
         
     | 
| 16 | 
         
            +
                controlnet=controlnet,
         
     | 
| 17 | 
         
            +
                transformer=transformer,
         
     | 
| 18 | 
         
            +
                torch_dtype=torch.bfloat16
         
     | 
| 19 | 
         
            +
            ).to("cuda")
         
     | 
| 20 | 
         
            +
            pipe.transformer.to(torch.bfloat16)
         
     | 
| 21 | 
         
            +
            pipe.controlnet.to(torch.bfloat16)
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
            def prepare_image_and_mask(image, width, height, overlap_percentage):
         
     | 
| 24 | 
         
            +
                # Resize the input image to fit within the target size
         
     | 
| 25 | 
         
            +
                image.thumbnail((width, height), Image.LANCZOS)
         
     | 
| 26 | 
         
            +
                
         
     | 
| 27 | 
         
            +
                # Create a new white background image of the target size
         
     | 
| 28 | 
         
            +
                background = Image.new('RGB', (width, height), (255, 255, 255))
         
     | 
| 29 | 
         
            +
                
         
     | 
| 30 | 
         
            +
                # Paste the resized image onto the background
         
     | 
| 31 | 
         
            +
                offset = ((width - image.width) // 2, (height - image.height) // 2)
         
     | 
| 32 | 
         
            +
                background.paste(image, offset)
         
     | 
| 33 | 
         
            +
                
         
     | 
| 34 | 
         
            +
                # Create a mask
         
     | 
| 35 | 
         
            +
                mask = Image.new('L', (width, height), 255)
         
     | 
| 36 | 
         
            +
                draw = ImageDraw.Draw(mask)
         
     | 
| 37 | 
         
            +
                
         
     | 
| 38 | 
         
            +
                # Calculate the overlap area
         
     | 
| 39 | 
         
            +
                overlap_x = int(image.width * overlap_percentage / 100)
         
     | 
| 40 | 
         
            +
                overlap_y = int(image.height * overlap_percentage / 100)
         
     | 
| 41 | 
         
            +
                
         
     | 
| 42 | 
         
            +
                # Draw the mask (black area is where we want to inpaint)
         
     | 
| 43 | 
         
            +
                draw.rectangle([
         
     | 
| 44 | 
         
            +
                    (offset[0] + overlap_x, offset[1] + overlap_y),
         
     | 
| 45 | 
         
            +
                    (offset[0] + image.width - overlap_x, offset[1] + image.height - overlap_y)
         
     | 
| 46 | 
         
            +
                ], fill=0)
         
     | 
| 47 | 
         
            +
                
         
     | 
| 48 | 
         
            +
                return background, mask
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            def inpaint(image, prompt, width, height, overlap_percentage, num_inference_steps, guidance_scale):
         
     | 
| 51 | 
         
            +
                # Prepare image and mask
         
     | 
| 52 | 
         
            +
                image, mask = prepare_image_and_mask(image, width, height, overlap_percentage)
         
     | 
| 53 | 
         
            +
                
         
     | 
| 54 | 
         
            +
                # Set up generator for reproducibility
         
     | 
| 55 | 
         
            +
                generator = torch.Generator(device="cuda").manual_seed(42)
         
     | 
| 56 | 
         
            +
                
         
     | 
| 57 | 
         
            +
                # Run inpainting
         
     | 
| 58 | 
         
            +
                result = pipe(
         
     | 
| 59 | 
         
            +
                    prompt=prompt,
         
     | 
| 60 | 
         
            +
                    height=height,
         
     | 
| 61 | 
         
            +
                    width=width,
         
     | 
| 62 | 
         
            +
                    control_image=image,
         
     | 
| 63 | 
         
            +
                    control_mask=mask,
         
     | 
| 64 | 
         
            +
                    num_inference_steps=num_inference_steps,
         
     | 
| 65 | 
         
            +
                    generator=generator,
         
     | 
| 66 | 
         
            +
                    controlnet_conditioning_scale=0.9,
         
     | 
| 67 | 
         
            +
                    guidance_scale=guidance_scale,
         
     | 
| 68 | 
         
            +
                    negative_prompt="",
         
     | 
| 69 | 
         
            +
                    true_guidance_scale=guidance_scale
         
     | 
| 70 | 
         
            +
                ).images[0]
         
     | 
| 71 | 
         
            +
                
         
     | 
| 72 | 
         
            +
                return result
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
            # Gradio interface
         
     | 
| 75 | 
         
            +
            with gr.Blocks() as demo:
         
     | 
| 76 | 
         
            +
                gr.Markdown("# FLUX Outpainting Demo")
         
     | 
| 77 | 
         
            +
                with gr.Row():
         
     | 
| 78 | 
         
            +
                    with gr.Column():
         
     | 
| 79 | 
         
            +
                        input_image = gr.Image(type="pil", label="Input Image")
         
     | 
| 80 | 
         
            +
                        prompt_input = gr.Textbox(label="Prompt")
         
     | 
| 81 | 
         
            +
                        width_slider = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=768)
         
     | 
| 82 | 
         
            +
                        height_slider = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=768)
         
     | 
| 83 | 
         
            +
                        overlap_slider = gr.Slider(label="Overlap Percentage", minimum=0, maximum=50, step=1, value=10)
         
     | 
| 84 | 
         
            +
                        steps_slider = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)
         
     | 
| 85 | 
         
            +
                        guidance_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=10.0, step=0.1, value=3.5)
         
     | 
| 86 | 
         
            +
                        run_button = gr.Button("Generate")
         
     | 
| 87 | 
         
            +
                    with gr.Column():
         
     | 
| 88 | 
         
            +
                        output_image = gr.Image(label="Output Image")
         
     | 
| 89 | 
         
            +
                
         
     | 
| 90 | 
         
            +
                run_button.click(
         
     | 
| 91 | 
         
            +
                    fn=inpaint,
         
     | 
| 92 | 
         
            +
                    inputs=[input_image, prompt_input, width_slider, height_slider, overlap_slider, steps_slider, guidance_slider],
         
     | 
| 93 | 
         
            +
                    outputs=output_image
         
     | 
| 94 | 
         
            +
                )
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
            demo.launch()
         
     |