Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from src.transformer import SymmetricTransformer2DModel | |
| from src.pipeline import UnifiedPipeline | |
| from src.scheduler import Scheduler | |
| from torchvision import transforms | |
| from transformers import CLIPTextModelWithProjection, CLIPTokenizer | |
| from diffusers import VQModel | |
| import os | |
| from PIL import Image | |
| import numpy as np | |
| import spaces | |
# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_models(model_path="MeissonFlow/Meissonic",
                transformer_path="MeissonFlow/Muddit"):
    """Assemble a ``UnifiedPipeline`` from pretrained components.

    The transformer is loaded from the ``1024/transformer`` subfolder of
    ``transformer_path``. The VQ model, CLIP text encoder, CLIP tokenizer
    and scheduler are loaded from the ``vqvae``, ``text_encoder``,
    ``tokenizer`` and ``scheduler`` subfolders of ``model_path``.

    Args:
        model_path: Location passed to ``from_pretrained`` for every
            component except the transformer.
        transformer_path: Location passed to ``from_pretrained`` for the
            transformer.

    Returns:
        The constructed ``UnifiedPipeline``.
    """
    # Components are loaded in the order listed here and handed to the
    # pipeline constructor by keyword.
    components = dict(
        transformer=SymmetricTransformer2DModel.from_pretrained(
            transformer_path, subfolder="1024/transformer"
        ),
        vqvae=VQModel.from_pretrained(model_path, subfolder="vqvae"),
        text_encoder=CLIPTextModelWithProjection.from_pretrained(
            model_path, subfolder="text_encoder"
        ),
        tokenizer=CLIPTokenizer.from_pretrained(
            model_path, subfolder="tokenizer"
        ),
        scheduler=Scheduler.from_pretrained(model_path, subfolder="scheduler"),
    )
    return UnifiedPipeline(**components)
# Build the pipeline a single time at import so that every request handler
# below reuses the same module-level object, then move it to `device`.
pipe = load_models()
pipe.to(device)
# Shared image preprocessing
def get_transform(resolution):
    """Return a torchvision transform that resizes to a square and tensorizes.

    Args:
        resolution: Side length used for both height and width in
            ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` applying ``Resize`` followed by ``ToTensor``.
    """
    target_size = (resolution, resolution)
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
# Image-to-Text Function
def image_to_text(image, prompt, seed=42, steps=64, cfg=9.0):
    """Generate text from an optional image and a prompt using the global pipeline.

    Args:
        image: A PIL image, a numpy array, or ``None``. Arrays are cast to
            ``uint8`` and converted to a PIL image; any image is converted
            to RGB before being resized and tensorized.
        prompt: Question or instruction. When falsy, a default question is
            used (which one depends on whether an image was supplied).
        seed: Seed passed to ``torch.manual_seed``; coerced to ``int``.
        steps: Number of inference steps; coerced to ``int``.
        cfg: Guidance scale forwarded to the pipeline.

    Returns:
        ``output.prompts[0]`` on success, otherwise the string
        ``"Error: <message>"`` describing the exception that was raised.
    """
    try:
        resolution = 1024
        if image is not None:
            if isinstance(image, np.ndarray):
                # Let Pillow infer the mode from the array shape instead of
                # forcing 'RGB', so grayscale / RGBA arrays are accepted too.
                image = Image.fromarray(image.astype('uint8'))
            # Normalise to three channels so RGBA, palette and grayscale
            # inputs all yield the same tensor layout after ToTensor().
            pil_image = image.convert('RGB')
            images = torch.stack([get_transform(resolution)(pil_image)])
            default_question = "Please describe this image."
        else:
            images = None
            default_question = "Please generate an image description."
        questions = [prompt or default_question]

        output = pipe(
            prompt=questions,
            image=images,
            height=resolution,
            width=resolution,
            guidance_scale=cfg,
            # UI sliders may deliver numeric values as floats; step count and
            # seed are integral quantities.
            num_inference_steps=int(steps),
            mask_token_embedding="./mask_token_embedding.pth",
            generator=torch.manual_seed(int(seed)),
        )
        return output.prompts[0]
    except Exception as e:
        # The result is shown in a textbox, so report the failure as text.
        return f"Error: {str(e)}"
# Text-to-Image Function
def text_to_image(prompt, negative_prompt, num_images=1, seed=42, steps=64, cfg=9.0):
    """Generate images for a prompt using the global pipeline.

    Args:
        prompt: Text description of the desired image.
        negative_prompt: Text describing what to avoid. When falsy, a
            built-in default negative prompt is used.
        num_images: How many images to generate; coerced to ``int`` because
            it is used as a list repetition count.
        seed: Seed passed to ``torch.manual_seed``; coerced to ``int``.
        steps: Number of inference steps; coerced to ``int``.
        cfg: Guidance scale forwarded to the pipeline.

    Returns:
        ``output.images`` from the pipeline call.

    Raises:
        gr.Error: If anything fails during generation, so the Gradio UI
            shows the message instead of silently leaving the gallery empty.
    """
    try:
        resolution = 1024
        batch_size = int(num_images)
        negative_prompt = negative_prompt or "worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark"

        output = pipe(
            prompt=[prompt] * batch_size,
            negative_prompt=[negative_prompt] * batch_size,
            height=resolution,
            width=resolution,
            guidance_scale=cfg,
            num_inference_steps=int(steps),
            mask_token_embedding="./mask_token_embedding.pth",
            generator=torch.manual_seed(int(seed)),
        )
        return output.images
    except Exception as e:
        # Keep the console log, but also surface the failure to the user.
        print(f"Error: {str(e)}")
        raise gr.Error(f"Error: {str(e)}") from e
# Create Gradio interface with Soft theme
#
# Three tabs share the same two backend functions:
#   * "Image to Text" and "VQA" both call `image_to_text`
#   * "Text to Image" calls `text_to_image`
# Each tab owns its own seed slider under a distinct name, so that a later
# tab's slider does not rebind the variable an earlier tab's click handler
# is wired to.
# NOTE(review): the nesting of the submit buttons and gr.Examples blocks was
# reconstructed from a whitespace-stripped source — confirm the layout.
with gr.Blocks(theme=gr.themes.Soft(), title="Muddit Unifined Model") as demo:
    gr.Markdown("# π Muddit: Liberating Generation Beyond Text-to-Image with a Unified Discrete Diffusion Model.")
    gr.Markdown(" Muddit is a unified discrete diffusion transformer that enables fast and parallel generation across both text and image modalities.")

    with gr.Tab("Image to Text"):
        with gr.Row():
            with gr.Column():
                i2t_image_input = gr.Image(label="Upload Image", type="pil")
                i2t_prompt_input = gr.Textbox(label="Prompt", value="Please describe this image.", placeholder="Enter your prompt here...")
                with gr.Accordion("Advanced Settings", open=False):
                    i2t_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42)
                    i2t_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
                    i2t_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
                i2t_submit_btn = gr.Button("Generate Description", variant="primary")
            with gr.Column():
                i2t_output_text = gr.Textbox(label="Generated Description", interactive=False)
        i2t_examples = gr.Examples(
            examples=[
                ["assets/man.jpg"],
                ["assets/tennis.jpg"],
                ["assets/pizza2.jpg"],
                ["assets/plane.jpg"],
                ["assets/zebra.jpg"],
                ["assets/building.jpg"],
                ["assets/flower.jpg"],
            ],
            inputs=[i2t_image_input],
            label="Example Inputs"
        )

    with gr.Tab("VQA"):
        with gr.Row():
            with gr.Column():
                vqa_image_input = gr.Image(label="Upload Image", type="pil")
                vqa_prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your question here...")
                with gr.Accordion("Advanced Settings", open=False):
                    vqa_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42)
                    vqa_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
                    vqa_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
                vqa_submit_btn = gr.Button("Generate Answer", variant="primary")
            with gr.Column():
                vqa_output_text = gr.Textbox(label="Generated Answer", interactive=False)
        vqa_examples = gr.Examples(
            examples=[
                ["assets/kid.jpg", "What color is the kid's hair?"],
                ["assets/street.jpg", "Can someone legally walk across the street right now?"],
                ["assets/dog.jpg", "Where is the dog laying?"],
                ["assets/dog2.jpg", "What color is the toy the dog is holding?"],
                ["assets/pizza.jpg", "What food item is shown?"],
                ["assets/sheep.jpg", "How many sheep are pictured?"],
                ["assets/car.jpg", "Where are the cars?"],
            ],
            inputs=[vqa_image_input, vqa_prompt_input],
            label="Example Inputs"
        )

    with gr.Tab("Text to Image"):
        with gr.Row():
            with gr.Column():
                t2i_prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate...")
                t2i_negative_prompt = gr.Textbox(label="Negative Prompt",
                                                 value="worst quality, low quality, low res, blurry, distortion, watermark, logo, signature, text, jpeg artifacts, signature, sketch, duplicate, ugly, identifying mark",
                                                 placeholder="What you don't want in the image...",
                                                 lines=5)
                t2i_num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, value=1, step=1)
                with gr.Accordion("Advanced Settings", open=False):
                    t2i_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, step=1, value=42)
                    t2i_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, value=64, step=1)
                    t2i_cfg = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=9.0, step=0.5)
                t2i_submit_btn = gr.Button("Generate Images", variant="primary")
            with gr.Column():
                t2i_gallery = gr.Gallery(label="Generated Images")
        t2i_examples = gr.Examples(
            examples=[
                ["A line art portrait showcasing a human figure with flowing, textured strokes"],
                ["A hyper realistic image of a chimpanzee with a glass-enclosed brain on his head, standing amidst lush, bioluminescent foliage in a vibrant futuristic forest"],
                ["A samurai in a stylized cyberpunk outfit adorned with intricate steampunk gear and floral accents, his Mandalorian armor gleaming under the backlighting"],
                ["A translucent, minimalist Porsche 911 GT3RS built from sleek carbon fiber, its aerodynamic body designed in the spirit of '60s Braun and modern Apple minimalism"],
                ["A realistic photograph of a ramadan tent shaped like a crescent moon under a velvety back sky studded with the milky way"],
                ["A portrait of John Lennon, captured in the gritty detail of line art"],
                ["In a world plunged into an unending darkness, remnants of fading starlight pierce through a heavy, smog-filled sky"]
            ],
            inputs=[t2i_prompt_input],
            label="Example Prompts"
        )

    # Event handlers: each button reads the controls of its own tab only.
    i2t_submit_btn.click(
        fn=image_to_text,
        inputs=[i2t_image_input, i2t_prompt_input, i2t_seed, i2t_steps, i2t_cfg],
        outputs=i2t_output_text
    )
    vqa_submit_btn.click(
        fn=image_to_text,
        inputs=[vqa_image_input, vqa_prompt_input, vqa_seed, vqa_steps, vqa_cfg],
        outputs=vqa_output_text
    )
    t2i_submit_btn.click(
        fn=text_to_image,
        inputs=[t2i_prompt_input, t2i_negative_prompt, t2i_num_images, t2i_seed, t2i_steps, t2i_cfg],
        outputs=t2i_gallery
    )

demo.launch()