Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import torch | |
| from diffusers import AutoPipelineForInpainting | |
| from PIL import Image, ImageFilter | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BlipForConditionalGeneration, | |
| BlipProcessor, | |
| Owlv2ForObjectDetection, | |
| Owlv2Processor, | |
| SamModel, | |
| SamProcessor, | |
| ) | |
def delete_model(model):
    """Move ``model`` off the GPU and release cached CUDA memory.

    Only the device of the weights changes here. The caller still holds its
    own reference to ``model``, so the Python object itself is freed once
    the caller drops that reference (a local ``del`` here could not do it).

    Args:
        model: Any object exposing ``.to(device)``, e.g. a ``torch.nn.Module``
            or a diffusers pipeline.
    """
    import gc

    model.to("cpu")
    # Reclaim unreachable objects (e.g. reference cycles) first so their CUDA
    # tensors go back to the caching allocator before the cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
| def run_language_model(edit_prompt, caption, device): | |
| language_model_id = "Qwen/Qwen1.5-0.5B-Chat" | |
| language_model = AutoModelForCausalLM.from_pretrained( | |
| language_model_id, device_map="auto" | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(language_model_id) | |
| messages = [ | |
| {"role": "system", "content": "Follow the examples and return the expected output"}, | |
| {"role": "user", "content": "Caption: a blue sky with fluffy clouds\nQuery: Make the sky stormy"}, | |
| {"role": "assistant", "content": "A: sky\nB: a stormy sky with heavy gray clouds, torrential rain, gloomy, overcast"}, | |
| {"role": "user", "content": "Caption: a cat sleeping on a sofa\nQuery: Change the cat to a dog"}, | |
| {"role": "assistant", "content": "A: cat\nB: a dog sleeping on a sofa, cozy and comfortable, snuggled up in a warm blanket, peaceful"}, | |
| {"role": "user", "content": "Caption: a snowy mountain peak\nQuery: Replace the snow with greenery"}, | |
| {"role": "assistant", "content": "A: snow\nB: a lush green mountain peak in summer, clear blue skies, birds flying overhead, serene and majestic"}, | |
| {"role": "user", "content": "Caption: a vintage car parked by the roadside\nQuery: Change the car to a modern electric vehicle"}, | |
| {"role": "assistant", "content": "A: car\nB: a sleek modern electric vehicle parked by the roadside, cutting-edge design, environmentally friendly, silent and powerful"}, | |
| {"role": "user", "content": "Caption: a wooden bridge over a river\nQuery: Make the bridge stone"}, | |
| {"role": "assistant", "content": "A: bridge\nB: an ancient stone bridge over a river, moss-covered, sturdy and timeless, with clear waters flowing beneath"}, | |
| {"role": "user", "content": "Caption: a bowl of salad on the table\nQuery: Replace salad with soup"}, | |
| {"role": "assistant", "content": "A: bowl\nB: a bowl of steaming hot soup on the table, scrumptious, with garnishing"}, | |
| {"role": "user", "content": "Caption: a book on a desk surrounded by stationery\nQuery: Remove all stationery, add a laptop"}, | |
| {"role": "assistant", "content": "A: stationery\nB: a book on a desk with a laptop next to it, modern study setup, focused and productive, technology and education combined"}, | |
| {"role": "user", "content": "Caption: a cup of coffee on a wooden table\nQuery: Change coffee to tea"}, | |
| {"role": "assistant", "content": "A: cup\nB: a steaming cup of tea on a wooden table, calming and aromatic, with a slice of lemon on the side, inviting"}, | |
| {"role": "user", "content": "Caption: a small pen on a white table\nQuery: Change the pen to an elaborate fountain pen"}, | |
| {"role": "assistant", "content": "A: pen\nB: an elaborate fountain pen on a white table, sleek and elegant, with intricate designs, ready for writing"}, | |
| {"role": "user", "content": "Caption: a plain notebook on a desk\nQuery: Replace the notebook with a journal"}, | |
| {"role": "assistant", "content": "A: notebook\nB: an artistically decorated journal on a desk, vibrant cover, filled with creativity, inspiring and personalized"}, | |
| {"role": "user", "content": f"Caption: {caption}\nQuery: {edit_prompt}"}, | |
| ] | |
| text = tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True | |
| ) | |
| model_inputs = tokenizer([text], return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| generated_ids = language_model.generate( | |
| model_inputs.input_ids, | |
| max_new_tokens=512, | |
| temperature=0.0, | |
| do_sample=False | |
| ) | |
| generated_ids = [ | |
| output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) | |
| ] | |
| response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| output_generation_a, output_generation_b = response.split("\n") | |
| to_replace = output_generation_a[2:].strip() | |
| replaced_caption = output_generation_b[2:].strip() | |
| delete_model(language_model) | |
| return (to_replace, replaced_caption) | |
def run_image_captioner(image, device):
    """Generate a caption for ``image`` with BLIP (base captioning model).

    Args:
        image: Input image (a PIL image when called from ``run_open_gen_fill``).
        device: Torch device string the model and inputs are moved to.

    Returns:
        The decoded caption string.
    """
    caption_model_id = "Salesforce/blip-image-captioning-base"
    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_id).to(
        device
    )
    try:
        caption_processor = BlipProcessor.from_pretrained(caption_model_id)
        inputs = caption_processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = caption_model.generate(**inputs, max_new_tokens=200)
        caption = caption_processor.decode(outputs[0], skip_special_tokens=True)
    finally:
        # Release the GPU even if preprocessing or generation raises.
        delete_model(caption_model)
    return caption
def run_segmentation(image, object_to_segment, device, threshold=0.1):
    """Return one merged mask covering every ``object_to_segment`` in ``image``.

    OWLv2 finds boxes for the text query, SAM converts each box into a mask,
    and the per-box masks are merged with an element-wise max.

    Args:
        image: PIL image to segment.
        object_to_segment: Text query for OWLv2, e.g. ``"cat"``.
        device: Torch device string used for both models.
        threshold: OWLv2 detection score threshold (defaults to the
            previously hard-coded 0.1).

    Returns:
        The merged mask tensor on the CPU (presumably (height, width) —
        confirm against SAM's ``post_process_masks`` output).

    Raises:
        ValueError: if OWLv2 detects no box for ``object_to_segment``.
    """
    # Stage 1: OWLv2 open-vocabulary object detection.
    owl_v2_model_id = "google/owlv2-base-patch16-ensemble"
    owl_processor = Owlv2Processor.from_pretrained(owl_v2_model_id)
    od_model = Owlv2ForObjectDetection.from_pretrained(owl_v2_model_id).to(device)
    try:
        inputs = owl_processor(
            text=[object_to_segment], images=image, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = od_model(**inputs)
        # The post-processor expects (height, width); PIL's ``size`` is
        # (width, height), so reverse it.
        # NOTE(review): OWLv2 pads inputs to a square; for non-square images
        # the padded size may be needed instead. Confirm before removing the
        # 512x512 resize done by the caller.
        target_sizes = torch.tensor([image.size[::-1]]).to(device)
        results = owl_processor.post_process_object_detection(
            outputs, threshold=threshold, target_sizes=target_sizes
        )[0]
        boxes = results["boxes"].tolist()
    finally:
        delete_model(od_model)

    # SAM cannot be prompted with zero boxes. Fail early with a readable
    # message instead of a cryptic error deep inside the processor.
    if not boxes:
        raise ValueError(
            f"Could not find '{object_to_segment}' in the image; "
            "try rephrasing the editing prompt."
        )

    # Stage 2: SAM segmentation prompted with the detected boxes.
    sam_model_id = "facebook/sam-vit-base"
    seg_model = SamModel.from_pretrained(sam_model_id).to(device)
    try:
        sam_processor = SamProcessor.from_pretrained(sam_model_id)
        inputs = sam_processor(
            image, input_boxes=[boxes], return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = seg_model(**inputs)
        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0]
    finally:
        delete_model(seg_model)

    # Take index 0 along dim 1 (presumably SAM's first candidate mask per
    # box), then merge across boxes.
    return torch.max(masks[:, 0, ...], dim=0, keepdim=False).values
def run_inpainting(image, replaced_caption, masks, generator, device):
    """Inpaint the masked region of ``image`` with SDXL, guided by a caption.

    Args:
        image: PIL image to edit.
        replaced_caption: Text prompt describing the desired result.
        masks: Mask tensor (e.g. from ``run_segmentation``). It is converted
            with ``Image.fromarray(masks.numpy())``, so it must be on the CPU.
        generator: ``torch.Generator`` seeding the diffusion sampling.
        device: Torch device string the pipeline runs on.

    Returns:
        The first generated PIL image.
    """
    # Use half precision only on CUDA. float16 is poorly supported on CPU,
    # so the "cpu" fallback chosen by the caller would otherwise break.
    on_cuda = torch.device(device).type == "cuda"
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16 if on_cuda else torch.float32,
        variant="fp16",
    ).to(device)
    try:
        mask_image = Image.fromarray(masks.numpy())
        negative_prompt = (
            "lowres, bad anatomy, bad hands,\n"
            "text, error, missing fingers, extra digit, fewer digits,\n"
            "cropped, worst quality, low quality"
        )
        output = pipeline(
            prompt=replaced_caption,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt,
            guidance_scale=7.5,
            strength=1.0,
            generator=generator,
        ).images[0]
    finally:
        # Release the GPU even if the diffusion run fails.
        delete_model(pipeline)
    return output
@spaces.GPU(duration=120)
def run_open_gen_fill(image, edit_prompt):
    """Run the caption -> LLM -> segmentation -> inpainting pipeline once.

    ``spaces`` is imported by this file but was never applied. On ZeroGPU
    Spaces, ``@spaces.GPU`` is what grants this call a GPU, and per the
    Hugging Face docs it has no effect elsewhere. The 120 s budget is a
    guess that leaves room for loading every model from scratch on each
    call; tune as needed.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.
        edit_prompt: Free-text instruction, e.g. "replace the dog with a tiger".

    Returns:
        Tuple ``(to_replace, caption, replaced_caption, mask_image, output)``,
        in the order of the five output components of the Gradio UI.

    Raises:
        ValueError: if ``image`` is missing or ``edit_prompt`` is blank.
    """
    if image is None:
        raise ValueError("Please upload an input image.")
    if not edit_prompt or not edit_prompt.strip():
        raise ValueError("Please enter an editing prompt.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Work at a fixed 512x512 resolution. The RGB conversion guards against
    # RGBA/greyscale inputs the vision models may not accept.
    image = image.convert("RGB").resize((512, 512))

    # Caption the input image.
    caption = run_image_captioner(image, device=device)

    # Let the language model pick the object to segment and write the
    # caption of the edited image.
    to_replace, replaced_caption = run_language_model(
        edit_prompt=edit_prompt, caption=caption, device=device
    )

    # Segment the `to_replace` object from the input image.
    masks = run_segmentation(image, to_replace, device=device)

    # Inpaint the masked region with a fixed seed for reproducible output.
    generator = torch.Generator(device).manual_seed(17)
    output = run_inpainting(
        image=image,
        replaced_caption=replaced_caption,
        masks=masks,
        generator=generator,
        device=device,
    )
    return (
        to_replace,
        caption,
        replaced_caption,
        Image.fromarray(masks.numpy()),
        output,
    )
def setup_gradio_interface():
    """Build (but do not launch) the Gradio Blocks UI for Open Generative Fill.

    Returns:
        The ``gr.Blocks`` application, with the Run button and the cached
        example both wired to ``run_open_gen_fill``.
    """
    with gr.Blocks() as block:
        # Heading tags must be closed; the previous markup reopened them.
        gr.Markdown("<h1><center>Open Generative Fill V1</center></h1>")
        with gr.Row():
            with gr.Column():
                input_image_placeholder = gr.Image(type="pil", label="Input Image")
                edit_prompt_placeholder = gr.Textbox(label="Enter the editing prompt")
                run_button_placeholder = gr.Button(value="Run")
            with gr.Column():
                to_replace_placeholder = gr.Textbox(label="to_replace")
                image_caption_placeholder = gr.Textbox(label="Image Caption")
                replace_caption_placeholder = gr.Textbox(label="Replaced Caption")
                segmentation_placeholder = gr.Image(type="pil", label="Segmentation")
                output_image_placeholder = gr.Image(type="pil", label="Output Image")

        # Shared by the button and the examples. The output order matches the
        # tuple returned by run_open_gen_fill.
        inputs = [input_image_placeholder, edit_prompt_placeholder]
        outputs = [
            to_replace_placeholder,
            image_caption_placeholder,
            replace_caption_placeholder,
            segmentation_placeholder,
            output_image_placeholder,
        ]

        # Gradio passes the input values positionally as (image, edit_prompt),
        # so no wrapper lambda is needed.
        run_button_placeholder.click(
            fn=run_open_gen_fill,
            inputs=inputs,
            outputs=outputs,
        )
        gr.Examples(
            examples=[["dog.jpeg", "replace the dog with a tiger"]],
            inputs=inputs,
            outputs=outputs,
            fn=run_open_gen_fill,
            cache_examples=True,
            label="Try this example input!",
        )
    return block
if __name__ == "__main__":
    # Build the UI and serve it locally (no public share link, API page hidden).
    # show_error=True asks Gradio to surface handler exceptions in the UI.
    # NOTE(review): a `queue(max_size=10)` call was previously left commented
    # out here; re-enable request queueing if concurrent load matters.
    app = setup_gradio_interface()
    app.launch(show_error=True, show_api=False, share=False)