# Hugging Face Space (status: Running)
| import os | |
| import gradio as gr | |
| import asyncio | |
| from dotenv import load_dotenv | |
| from huggingface_hub import InferenceClient, hf_hub_download, model_info | |
| from functools import partial | |
# Load environment variables
load_dotenv()

# The access token is mandatory: abort at import time when it is missing
# or empty instead of failing later on the first API call.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError("Please set HF_TOKEN environment variable")

# Available models (offered as choices in both model dropdowns)
AVAILABLE_MODELS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "NousResearch/Hermes-3-Llama-3.1-8B",
    "mistralai/Mistral-Nemo-Base-2407",
    "meta-llama/Llama-2-70b-hf",
    "aaditya/Llama3-OpenBioLLM-8B",
]

# Initialize inference client (created once and shared by every request)
inference_client = InferenceClient(token=HF_TOKEN)
def get_model_card_html(model_name, title):
    """Fetch model metadata and format it as an HTML model card.

    Args:
        model_name: Repository id of the model (e.g. ``"org/model"``).
        title: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        An HTML string. When the metadata lookup succeeds it lists the
        model id, pipeline tag, download count and like count; when the
        lookup (or formatting) fails for any reason it is a reduced card
        containing only the model name and a link.
    """
    # Function-scope imports keep this block self-contained.
    import logging
    from html import escape

    def format_count(value):
        # Counts may be missing; formatting None with ":," would raise and
        # throw away the rest of the card.
        return "Not available" if value is None else f"{value:,}"

    # Escape everything interpolated into markup so unexpected characters
    # in a name or tag cannot break (or inject into) the rendered HTML.
    safe_name = escape(str(model_name), quote=True)
    link_html = (
        f'<p><a href="https://huggingface.co/{safe_name}" target="_blank">'
        "View on Hugging Face</a></p>"
    )

    try:
        info = model_info(model_name, token=HF_TOKEN)
        model_id = escape(str(info.modelId))
        pipeline_tag = escape(str(info.pipeline_tag or "Not specified"))
        downloads = format_count(info.downloads)
        likes = format_count(info.likes)
        return f"""
        <div class="model-card-container">
            <h3>{model_id}</h3>
            <p><strong>Pipeline Tag:</strong> {pipeline_tag}</p>
            <p><strong>Downloads:</strong> {downloads}</p>
            <p><strong>Likes:</strong> {likes}</p>
            {link_html}
        </div>
        """
    except Exception as exc:
        # Deliberate best-effort: the card is decorative, so any failure
        # degrades to a minimal card instead of breaking the UI. The cause
        # is logged rather than silently discarded.
        logging.getLogger(__name__).warning(
            "Could not load model card for %s: %s", model_name, exc
        )
        return f"""
        <div class="model-card-container">
            <h3>{safe_name}</h3>
            <p>Unable to load full model card information.</p>
            {link_html}
        </div>
        """
async def get_model_response(prompt, model_name, temperature_value, do_sample, max_tokens):
    """Get a text-generation response from a Hugging Face model.

    Args:
        prompt: Text sent to the model.
        model_name: Model id passed to the inference client as ``model``.
        temperature_value: Sampling temperature; only forwarded when
            ``do_sample`` is true and the value is greater than zero.
        do_sample: Whether sampling is enabled.
        max_tokens: Value forwarded as ``max_new_tokens``.

    Returns:
        The generated text, possibly followed by a truncation warning.
        Any exception is caught and returned as the string
        ``"Error: <message>"`` instead of being raised.
    """
    # Rough estimate of the tokens-to-characters ratio used by the
    # truncation heuristic below.
    chars_per_token = 4

    try:
        # Build kwargs dynamically
        generation_args = {
            "prompt": prompt,
            "model": model_name,
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
            "return_full_text": False,
        }

        # Only include temperature if sampling is enabled
        if do_sample and temperature_value > 0:
            generation_args["temperature"] = temperature_value

        # Run the blocking client call in the default executor so the
        # event loop stays responsive. get_running_loop() is the
        # documented way to obtain the loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            partial(inference_client.text_generation, **generation_args),
        )

        # Heuristic only: a response about as long as the token budget
        # allows is flagged as possibly cut off.
        if len(response) >= max_tokens * chars_per_token:
            response += "\n\n[Warning: Response may have been truncated. Try increasing the max tokens if the response seems incomplete.]"

        return response
    except Exception as e:
        # Errors are surfaced as text so they can be shown in the chat UI.
        return f"Error: {str(e)}"
async def process_single_response(prompt, model_name, temp, do_sample, max_tokens, chatbot):
    """Query one model and wrap the exchange as a two-message chat history.

    The ``chatbot`` argument is accepted but not used by this function.
    Returns a list holding the user message (the prompt) followed by the
    assistant message (whatever ``get_model_response`` returned).
    """
    reply_text = await get_model_response(prompt, model_name, temp, do_sample, max_tokens)
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": reply_text}
    return [user_turn, assistant_turn]
async def compare_models(prompt, model1, model2, temp1, temp2, do_sample1, do_sample2, max_tokens1, max_tokens2):
    """Compare outputs from two selected models, streaming progress.

    This is an async generator. Every yielded item is a 3-tuple
    ``(chat1, chat2, button_update)``: the message lists for the two
    chatbots and a ``gr.update`` toggling the submit button's
    interactivity. While a model is still running its chat shows a
    "Generating... (Ns)" placeholder; the final yield carries both results
    and re-enables the button.

    Args:
        prompt: The user prompt sent to both models.
        model1, model2: Model ids for the left and right columns.
        temp1, temp2: Temperature values for each model.
        do_sample1, do_sample2: Sampling flags for each model.
        max_tokens1, max_tokens2: Max-new-token limits for each model.
    """
    # Treat a missing prompt like an empty one instead of crashing on .strip().
    prompt = prompt or ""

    def as_chat(reply):
        # Two-message history in the "messages" format used by the chatbots.
        return [{"role": "user", "content": prompt}, {"role": "assistant", "content": reply}]

    if not prompt.strip():
        empty_response = as_chat("Please enter a prompt")
        yield empty_response, empty_response, gr.update(interactive=True)
        return  # Exit the generator

    # Initialize with "Generating..." messages
    initial_message = as_chat("Generating...")
    yield initial_message, initial_message, gr.update(interactive=False)

    # Create tasks for both model responses so they run concurrently
    task1 = asyncio.create_task(process_single_response(prompt, model1, temp1, do_sample1, max_tokens1, "chatbot1"))
    task2 = asyncio.create_task(process_single_response(prompt, model2, temp2, do_sample2, max_tokens2, "chatbot2"))

    chat1 = chat2 = initial_message
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    try:
        while not (task1.done() and task2.done()):
            # Update the placeholders with elapsed time
            elapsed = round(loop.time() - start_time, 1)
            if not task1.done():
                chat1 = as_chat(f"Generating... ({elapsed:.1f}s)")
            if not task2.done():
                chat2 = as_chat(f"Generating... ({elapsed:.1f}s)")

            # Wait briefly for either unfinished task; the timeout paces
            # the progress updates.
            unfinished = [t for t in (task1, task2) if not t.done()]
            await asyncio.wait(unfinished, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)

            # Read results straight from the tasks' state. A task can also
            # finish while this generator is suspended at the yield below,
            # in which case it never shows up in asyncio.wait's "done" set;
            # checking done() directly makes sure its result is not lost.
            if task1.done():
                chat1 = task1.result()
            if task2.done():
                chat2 = task2.result()

            yield chat1, chat2, gr.update(interactive=False)

        # Awaiting a finished task returns its result immediately, so do it
        # unconditionally to guarantee the final yield has both results.
        chat1 = await task1
        chat2 = await task2

        # Final yield with both results
        yield chat1, chat2, gr.update(interactive=True)
    except Exception as e:
        error_message = as_chat(f"Error: {str(e)}")
        yield error_message, error_message, gr.update(interactive=True)
    finally:
        # If the generator stops early (error or consumer closing it), do
        # not leave the background tasks pending.
        for task in (task1, task2):
            if not task.done():
                task.cancel()
# Update temperature slider interactivity based on sampling checkbox
def update_slider_state(enabled):
    """Return the two updates applied to a temperature slider.

    The first update sets the slider's ``interactive`` flag to ``enabled``.
    The second sets its CSS classes and value: no extra classes and
    ``value=None`` when enabled, the ``disabled-slider`` class and
    ``value=0`` when disabled.
    """
    interactivity = gr.update(interactive=enabled)
    if enabled:
        appearance = gr.update(elem_classes=[], value=None)
    else:
        appearance = gr.update(elem_classes=["disabled-slider"], value=0)
    return [interactivity, appearance]
# Create the Gradio interface
def _build_model_column(default_model, dropdown_label, card_title, chatbot_label):
    """Create one model's column of controls in the current layout context.

    Components are created in this order: model dropdown, model card,
    sampling checkbox, temperature slider, max-token slider, chatbot.
    They are returned as a tuple in that same order.
    """
    with gr.Column():
        dropdown = gr.Dropdown(
            choices=AVAILABLE_MODELS,
            value=default_model,
            label=dropdown_label,
        )
        card = gr.HTML(
            value=get_model_card_html(default_model, card_title),
            elem_classes=["model-card-container"],
        )
        sampling = gr.Checkbox(
            label="Enable sampling (random outputs)",
            value=False,
        )
        # The temperature slider starts disabled because sampling starts off.
        temperature = gr.Slider(
            label="Temperature (Higher = more creative, lower = more predictable)",
            minimum=0,
            maximum=1,
            step=0.1,
            value=0.0,
            interactive=False,
            elem_classes=["disabled-slider"],
        )
        token_limit = gr.Slider(
            label="Maximum new tokens in response",
            minimum=10,
            maximum=2000,
            step=10,
            value=10,
        )
        chat = gr.Chatbot(
            label=chatbot_label,
            show_label=True,
            height=300,
            type="messages",
        )
    return dropdown, card, sampling, temperature, token_limit, chat


with gr.Blocks(css="""
    .disabled-slider { opacity: 0.5; pointer-events: none; }
    .model-card-container {
        background-color: #f8f9fa;
        font-size: 14px;
        color: #666;
    }
    .model-card-container h3 {
        margin: 0;
        color: black;
    }
    .model-card-container p {
        margin: 5px 0;
    }
""") as demo:
    # Page header
    gr.Markdown("# LLM Comparison Tool")
    gr.Markdown("Using HuggingFace's Inference API, compare outputs from different `text-generation` models side by side.")

    # Prompt input shared by both models
    with gr.Row():
        prompt = gr.Textbox(
            label="Enter your prompt",
            placeholder="Type your prompt here...",
            lines=3,
        )

    with gr.Row():
        submit_btn = gr.Button("Generate Responses")

    # Two side-by-side columns, one per model
    with gr.Row():
        (
            model1_dropdown,
            model1_card,
            do_sample1,
            temp1,
            max_tokens1,
            chatbot1,
        ) = _build_model_column(
            AVAILABLE_MODELS[0],
            "Select Model 1",
            "Model 1 Information",
            "Model 1 Output",
        )
        (
            model2_dropdown,
            model2_card,
            do_sample2,
            temp2,
            max_tokens2,
            chatbot2,
        ) = _build_model_column(
            AVAILABLE_MODELS[1],
            "Select Model 2",
            "Model 2 Information",
            "Model 2 Output",
        )

    def start_loading():
        """Return an update that disables the submit button."""
        return gr.update(interactive=False)

    # Handle form submission: first disable the button without queuing,
    # then run the streaming comparison.
    submit_btn.click(
        fn=start_loading,
        inputs=None,
        outputs=submit_btn,
        queue=False,
    ).then(
        fn=compare_models,
        inputs=[prompt, model1_dropdown, model2_dropdown, temp1, temp2, do_sample1, do_sample2, max_tokens1, max_tokens2],
        outputs=[chatbot1, chatbot2, submit_btn],
        queue=True,  # Enable queuing for streaming updates
    )

    # Update model cards when models are changed
    model1_dropdown.change(
        fn=lambda x: get_model_card_html(x, "Model 1 Information"),
        inputs=[model1_dropdown],
        outputs=[model1_card],
    )
    model2_dropdown.change(
        fn=lambda x: get_model_card_html(x, "Model 2 Information"),
        inputs=[model2_dropdown],
        outputs=[model2_card],
    )

    # Sampling checkboxes drive their temperature sliders. The slider is
    # listed twice because update_slider_state returns two updates for it.
    do_sample1.change(
        fn=update_slider_state,
        inputs=[do_sample1],
        outputs=[temp1, temp1],
    )
    do_sample2.change(
        fn=update_slider_state,
        inputs=[do_sample2],
        outputs=[temp2, temp2],
    )
# Script entry point: enable the request queue, then start the app.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()