import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess
import requests
import json
# Install FlashAttention; pass through the current environment (so pip stays on
# PATH) and skip the CUDA build in favor of the prebuilt wheel.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image
from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download
# Get Brave Search API key
BSEARCH_API = os.getenv("BSEARCH_API")

save_dir = "./model_weights"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = os.path.join(save_dir, "cache")

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)
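# Note: local_dir_use_symlinks and resume_download are deprecated (and ignored)
# in recent huggingface_hub releases; they are kept here for compatibility with
# older versions.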
# Model Initialization
model_path = save_dir

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)
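# These sizes line up with the image transforms below: the VAE path targets up
# to 1024 px with a 16x downsample (1024 / 16 = 64 = max_latent_size), and the
# ViT path targets up to 980 px with 14 px patches
# (980 / 14 = 70 = vit_max_num_patch_per_side).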
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)
# Model loading and multi-GPU inference preparation
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)
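# no_split_module_classes keeps each MoT decoder layer intact on one device.
# The modules listed below are additionally pinned to a single shared device:
# they operate on the same embedding/latent stream, so keeping them together
# avoids cross-device transfers.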
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    offload_folder="offload",
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()
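# Weights come from the EMA checkpoint in bfloat16. offload_buffers and
# offload_folder let accelerate spill modules that do not fit in GPU memory
# to disk, trading speed for capacity.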
# Inferencer preparation
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)
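# The inferencer behaves as a generator: it yields interleaved text chunks
# (str) and PIL images, which the callbacks below accumulate and stream
# incrementally to the Gradio UI.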
# Brave Search function
def brave_search(query):
    """Perform a web search using the Brave Search API; return None on failure."""
    if not BSEARCH_API:
        return None

    try:
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": BSEARCH_API
        }
        url = "https://api.search.brave.com/res/v1/web/search"
        params = {
            "q": query,
            "count": 5
        }
        response = requests.get(url, headers=headers, params=params, timeout=10)
        response.raise_for_status()

        data = response.json()
        results = []
        if "web" in data and "results" in data["web"]:
            for idx, result in enumerate(data["web"]["results"][:5], 1):
                title = result.get("title", "No title")
                result_url = result.get("url", "")
                description = result.get("description", "No description")
                results.append(f"{idx}. {title}\nURL: {result_url}\n{description}")

        if results:
            return "\n\n".join(results)
        return None
    except Exception as e:
        print(f"Search error: {e}")
        return None
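# Illustrative shape of a formatted search result (actual titles, URLs, and
# descriptions come from the live Brave Search response):
#
#   1. Example Title
#   URL: https://example.com/
#   Example description...
#
#   2. Another Title
#   ...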
def enhance_prompt_with_search(prompt, use_search=False):
    """Enhance the prompt with web search results if enabled."""
    if not use_search or not BSEARCH_API:
        return prompt

    search_results = brave_search(prompt)
    if search_results:
        return (
            f"{prompt}\n\n[Web Search Context]:\n{search_results}\n\n"
            f"[Generate based on the above context and original prompt]"
        )
    return prompt
def set_seed(seed):
    """Seed all RNGs for reproducibility; a seed of 0 leaves them unseeded (random)."""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    return seed
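# Note: cudnn.deterministic trades some speed for reproducibility, and certain
# CUDA kernels can remain nondeterministic even with all seeds fixed.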
# Text-to-image function with thinking option and hyperparameters
def text_to_image(prompt, use_web_search=False, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                  timestep_shift=3.0, num_timesteps=50,
                  cfg_renorm_min=1.0, cfg_renorm_type="global",
                  max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                  seed=0, image_ratio="1:1"):
    # Set seed for reproducibility
    set_seed(seed)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Map the aspect ratio to (height, width); the longer side is fixed at 1024
    if image_ratio == "1:1":
        image_shapes = (1024, 1024)
    elif image_ratio == "4:3":
        image_shapes = (768, 1024)
    elif image_ratio == "3:4":
        image_shapes = (1024, 768)
    elif image_ratio == "16:9":
        image_shapes = (576, 1024)
    elif image_ratio == "9:16":
        image_shapes = (1024, 576)
    else:
        image_shapes = (1024, 1024)  # fall back to square for unknown ratios

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        text_temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=[cfg_interval, 1.0],  # end fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        image_shapes=image_shapes,
    )

    result = {"text": "", "image": None}
    # Call the inferencer with or without the think parameter based on user choice
    for i in inferencer(text=enhanced_prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result["text"]
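# Hypothetical usage outside Gradio (the function is a generator, so iterate
# it to stream partial results):
#
#   for image, thinking in text_to_image("a red bicycle at sunset", seed=42):
#       pass  # `image` is the latest PIL image (or None); `thinking` is the text so far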
# Image understanding function with thinking option and hyperparameters
def image_understanding(image: Image.Image, prompt: str, use_web_search=False, show_thinking=False,
                        do_sample=False, text_temperature=0.3, max_new_tokens=512):
    # This is a generator, so report missing input via yield rather than return
    if image is None:
        yield "Please upload an image."
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Set hyperparameters
    inference_hyper = dict(
        do_sample=do_sample,
        text_temperature=text_temperature,
        max_think_token_n=max_new_tokens,  # used as the max generation length
    )

    result = {"text": "", "image": None}
    # Use the show_thinking parameter to control the thinking process
    for i in inferencer(image=image, text=enhanced_prompt, think=show_thinking,
                        understanding_output=True, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["text"]
# Image editing function with thinking option and hyperparameters
def edit_image(image: Image.Image, prompt: str, use_web_search=False, show_thinking=False, cfg_text_scale=4.0,
               cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
               cfg_renorm_type="text_channel", max_think_token_n=1024,
               do_sample=False, text_temperature=0.3, seed=0):
    # Set seed for reproducibility
    set_seed(seed)

    # This is a generator, so report missing input via yield rather than return
    if image is None:
        yield None, "Please upload an image."
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        text_temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0],  # end fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    # Include the thinking parameter based on user choice
    result = {"text": "", "image": None}
    for i in inferencer(image=image, text=enhanced_prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result["text"]
# Helper function to load example images
def load_example_image(image_path):
    try:
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
        return None
# Enhanced CSS for visual improvements
custom_css = """
/* Modern gradient background */
.gradio-container {
    background: linear-gradient(135deg, #1e3c72 0%, #2a5298 50%, #3a6fb0 100%);
    min-height: 100vh;
}

/* Main container with glassmorphism */
.container {
    backdrop-filter: blur(10px);
    background: rgba(255, 255, 255, 0.1);
    border-radius: 20px;
    padding: 30px;
    margin: 20px auto;
    max-width: 1400px;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
}

/* Header styling */
h1 {
    background: linear-gradient(90deg, #ffffff 0%, #e0e0e0 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3.5em;
    text-align: center;
    margin-bottom: 30px;
    font-weight: 800;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
}

/* Tab styling */
.tabs {
    background: rgba(255, 255, 255, 0.15);
    border-radius: 15px;
    padding: 10px;
    margin-bottom: 20px;
}

.tab-nav {
    background: rgba(255, 255, 255, 0.2) !important;
    border-radius: 10px !important;
    padding: 5px !important;
}

.tab-nav button {
    background: transparent !important;
    color: white !important;
    border: none !important;
    padding: 10px 20px !important;
    margin: 0 5px !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}

.tab-nav button.selected {
    background: rgba(255, 255, 255, 0.3) !important;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2) !important;
}

.tab-nav button:hover {
    background: rgba(255, 255, 255, 0.25) !important;
}

/* Input field styling */
.textbox, .image-container {
    background: rgba(255, 255, 255, 0.95) !important;
    border: 2px solid rgba(255, 255, 255, 0.3) !important;
    border-radius: 12px !important;
    padding: 15px !important;
    color: #333 !important;
    font-size: 16px !important;
    transition: all 0.3s ease !important;
}

.textbox:focus {
    border-color: #3a6fb0 !important;
    box-shadow: 0 0 20px rgba(58, 111, 176, 0.4) !important;
}

/* Button styling */
.primary {
    background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%) !important;
    color: white !important;
    border: none !important;
    padding: 12px 30px !important;
    border-radius: 10px !important;
    font-weight: 600 !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(76, 175, 80, 0.3) !important;
}

.primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(76, 175, 80, 0.4) !important;
}

/* Checkbox styling */
.checkbox-group {
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 10px 15px !important;
    border-radius: 8px !important;
    margin: 10px 0 !important;
}

.checkbox-group label {
    color: white !important;
    font-weight: 500 !important;
}

/* Accordion styling */
.accordion {
    background: rgba(255, 255, 255, 0.1) !important;
    border-radius: 12px !important;
    margin: 15px 0 !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
}

.accordion-header {
    background: rgba(255, 255, 255, 0.15) !important;
    color: white !important;
    padding: 12px 20px !important;
    border-radius: 10px !important;
    font-weight: 600 !important;
}

/* Slider styling */
.slider {
    background: rgba(255, 255, 255, 0.2) !important;
    border-radius: 5px !important;
}

.slider .handle {
    background: white !important;
    border: 3px solid #3a6fb0 !important;
}

/* Image output styling */
.image-frame {
    border-radius: 15px !important;
    overflow: hidden !important;
    box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3) !important;
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 10px !important;
}

/* Footer links */
a {
    color: #64b5f6 !important;
    text-decoration: none !important;
    font-weight: 500 !important;
    transition: color 0.3s ease !important;
}

a:hover {
    color: #90caf9 !important;
}

/* Web search info box */
.web-search-info {
    background: linear-gradient(135deg, rgba(255, 193, 7, 0.2) 0%, rgba(255, 152, 0, 0.2) 100%);
    border: 2px solid rgba(255, 193, 7, 0.5);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: white;
}

.web-search-info h4 {
    margin: 0 0 10px 0;
    color: #ffd54f;
    font-size: 1.2em;
}

.web-search-info p {
    margin: 5px 0;
    font-size: 0.95em;
    line-height: 1.4;
}

/* Loading animation */
.generating {
    border-color: #4CAF50 !important;
    animation: pulse 2s infinite !important;
}

@keyframes pulse {
    0% {
        box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.7);
    }
    70% {
        box-shadow: 0 0 0 10px rgba(76, 175, 80, 0);
    }
    100% {
        box-shadow: 0 0 0 0 rgba(76, 175, 80, 0);
    }
}
"""
# Gradio UI
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div class="container">
        <h1>🥯 BAGEL - Bootstrapping Aligned Generation with Exponential Learning</h1>
        <p style="text-align: center; color: #e0e0e0; font-size: 1.2em; margin-bottom: 30px;">
            Advanced AI Model for Text-to-Image, Image Editing, and Image Understanding
        </p>
    </div>
    """)
    with gr.Tab("📝 Text to Image"):
        txt_input = gr.Textbox(
            label="Prompt",
            value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.",
            lines=3
        )

        with gr.Row():
            use_web_search = gr.Checkbox(
                label="🔍 Enable Web Search",
                value=False,
                info="Search the web for current information to enhance your prompt"
            )
            show_thinking = gr.Checkbox(label="💭 Show Thinking Process", value=False)

        # Web search information box
        web_search_info = gr.HTML("""
        <div class="web-search-info" style="display: none;">
            <h4>🔍 Brave Web Search Integration</h4>
            <p>When enabled, BAGEL will search the web for relevant information about your prompt and incorporate current trends, references, and context into the image generation process.</p>
            <p>This is particularly useful for:</p>
            <ul style="margin-left: 20px;">
                <li>Current events and trending topics</li>
                <li>Specific art styles or references</li>
                <li>Technical or specialized subjects</li>
                <li>Pop culture references</li>
            </ul>
        </div>
        """, visible=False)

        # Show/hide the web search info based on the checkbox
        def toggle_search_info(use_search):
            return gr.update(visible=use_search)

        use_web_search.change(toggle_search_info, inputs=[use_web_search], outputs=[web_search_info])

        # Hyperparameter controls in an accordion
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            # Parameters laid out in rows
            with gr.Group():
                with gr.Row():
                    seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
                                     label="Seed", info="0 for random seed, positive for reproducible results")
                    image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
                                              value="1:1", label="Image Ratio",
                                              info="The longer side is fixed to 1024")
                with gr.Row():
                    cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                               label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
                    cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
                                             label="CFG Interval", info="Start of the CFG application interval (end is fixed at 1.0)")
                with gr.Row():
                    cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                  value="global", label="CFG Renorm Type",
                                                  info="If the generated image is blurry, use 'global'")
                    cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                               label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
                with gr.Row():
                    num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                              label="Timesteps", info="Total denoising steps")
                    timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
                                               label="Timestep Shift", info="Higher values for layout, lower for details")

            # Thinking parameters in a single row
            thinking_params = gr.Group(visible=False)
            with thinking_params:
                with gr.Row():
                    do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                    max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, interactive=True,
                                                  label="Max Think Tokens", info="Maximum number of tokens for thinking")
                    text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                 label="Temperature", info="Controls randomness in text generation")

        thinking_output = gr.Textbox(label="Thinking Process", visible=False)
        img_output = gr.Image(label="Generated Image", elem_classes=["image-frame"])
        gen_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

        # Dynamically show/hide the thinking-process box and parameters
        def update_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        show_thinking.change(
            fn=update_thinking_visibility,
            inputs=[show_thinking],
            outputs=[thinking_output, thinking_params]
        )

        gr.on(
            triggers=[gen_btn.click, txt_input.submit],
            fn=text_to_image,
            inputs=[
                txt_input, use_web_search, show_thinking, cfg_text_scale,
                cfg_interval, timestep_shift,
                num_timesteps, cfg_renorm_min, cfg_renorm_type,
                max_think_token_n, do_sample, text_temperature, seed, image_ratio
            ],
            outputs=[img_output, thinking_output]
        )
    with gr.Tab("🖌️ Image Edit"):
        with gr.Row():
            with gr.Column(scale=1):
                edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'), elem_classes=["image-frame"])
                edit_prompt = gr.Textbox(
                    label="Edit Prompt",
                    value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.",
                    lines=2
                )
            with gr.Column(scale=1):
                edit_image_output = gr.Image(label="Edited Result", elem_classes=["image-frame"])
                edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)

        with gr.Row():
            edit_use_web_search = gr.Checkbox(
                label="🔍 Enable Web Search",
                value=False,
                info="Search for references and context to improve editing"
            )
            edit_show_thinking = gr.Checkbox(label="💭 Show Thinking Process", value=False)

        # Hyperparameter controls in an accordion
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Group():
                with gr.Row():
                    edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
                                          label="Seed", info="0 for random seed, positive for reproducible results")
                    edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                                    label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
                with gr.Row():
                    edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
                                                   label="CFG Image Scale", info="Controls how much the model preserves input image details")
                    edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                  label="CFG Interval", info="Start of the CFG application interval (end is fixed at 1.0)")
                with gr.Row():
                    edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                       value="text_channel", label="CFG Renorm Type",
                                                       info="If the generated image is blurry, use 'global'")
                    edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                    label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
                with gr.Row():
                    edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                                   label="Timesteps", info="Total denoising steps")
                    edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
                                                    label="Timestep Shift", info="Higher values for layout, lower for details")

            # Thinking parameters in a single row
            edit_thinking_params = gr.Group(visible=False)
            with edit_thinking_params:
                with gr.Row():
                    edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                    edit_max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, interactive=True,
                                                       label="Max Think Tokens", info="Maximum number of tokens for thinking")
                    edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                      label="Temperature", info="Controls randomness in text generation")

        edit_btn = gr.Button("✏️ Apply Edit", variant="primary", size="lg")

        # Dynamically show/hide the thinking-process box for editing
        def update_edit_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        edit_show_thinking.change(
            fn=update_edit_thinking_visibility,
            inputs=[edit_show_thinking],
            outputs=[edit_thinking_output, edit_thinking_params]
        )

        gr.on(
            triggers=[edit_btn.click, edit_prompt.submit],
            fn=edit_image,
            inputs=[
                edit_image_input, edit_prompt, edit_use_web_search, edit_show_thinking,
                edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
                edit_timestep_shift, edit_num_timesteps,
                edit_cfg_renorm_min, edit_cfg_renorm_type,
                edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
            ],
            outputs=[edit_image_output, edit_thinking_output]
        )
    with gr.Tab("🖼️ Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'), elem_classes=["image-frame"])
                understand_prompt = gr.Textbox(
                    label="Question",
                    value="Can someone explain what's funny about this meme?",
                    lines=2
                )
            with gr.Column(scale=1):
                txt_output = gr.Textbox(label="AI Response", lines=20)

        with gr.Row():
            understand_use_web_search = gr.Checkbox(
                label="🔍 Enable Web Search",
                value=False,
                info="Search for context and references to better understand the image"
            )
            understand_show_thinking = gr.Checkbox(label="💭 Show Thinking Process", value=False)

        # Hyperparameter controls in an accordion
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
                                                        label="Temperature", info="Controls randomness in text generation (0 = deterministic, 1 = creative)")
                understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
                                                      label="Max New Tokens", info="Maximum length of generated text, including any thinking")

        img_understand_btn = gr.Button("🔍 Analyze Image", variant="primary", size="lg")

        gr.on(
            triggers=[img_understand_btn.click, understand_prompt.submit],
            fn=image_understanding,
            inputs=[
                img_input, understand_prompt, understand_use_web_search, understand_show_thinking,
                understand_do_sample, understand_text_temperature, understand_max_new_tokens
            ],
            outputs=txt_output
        )
demo.launch(share=True)