# Magistral Chat App (Hugging Face Space, running on ZeroGPU)
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from typing import Any
import base64
import mimetypes
from pathlib import Path


def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r") as file:
        system_prompt = file.read()
    # SYSTEM_PROMPT.txt embeds a [THINK]...[/THINK] example block; keep only
    # the text before and after it.
    index_begin_think = system_prompt.find("[THINK]")
    index_end_think = system_prompt.find("[/THINK]")
    return {
        "role": "system",
        "content": [
            {"type": "text", "text": system_prompt[:index_begin_think]},
            {
                "type": "text",
                "text": system_prompt[index_end_think + len("[/THINK]") :],
            },
        ],
    }


model_id = "mistralai/Magistral-Small-2509"
tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt")


@spaces.GPU  # ZeroGPU allocates a GPU only inside functions decorated with spaces.GPU
def predict(message: dict, history: list) -> str:
    # Build messages for the model from history
    messages = [SYSTEM_PROMPT]
    for user_msg, assistant_msg in history:
        # In tuple-style multimodal history, file-only turns arrive as tuples;
        # forward only plain-text turns to the model.
        if isinstance(user_msg, str):
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
    # Process the current user message (with an optional image)
    user_content = [{"type": "text", "text": message["text"]}]
    if message["files"]:
        # Assume one image file from the multimodal textbox; embed it as a base64 data URL
        image_path = Path(message["files"][0])
        image_bytes = image_path.read_bytes()
        encoded_image = base64.b64encode(image_bytes).decode("utf-8")
        mime_type, _ = mimetypes.guess_type(image_path)
        if mime_type is None:
            mime_type = "image/png"
        data_url = f"data:{mime_type};base64,{encoded_image}"
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})
    messages.append({"role": "user", "content": user_content})
    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)
    input_ids = torch.tensor(tokenized.input_ids, device="cuda").unsqueeze(0)
    attention_mask = torch.tensor(tokenized.attention_mask, device="cuda").unsqueeze(0)
    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:
        pixel_values = torch.tensor(
            tokenized.pixel_values[0], dtype=torch.bfloat16, device="cuda"
        ).unsqueeze(0)
        image_sizes = torch.tensor(pixel_values.shape[-2:], device="cuda").unsqueeze(0)
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            max_new_tokens=2048,  # assumed cap: reasoning traces run long; tune as needed
        )[0]
    else:
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=2048,  # assumed cap: reasoning traces run long; tune as needed
        )[0]
    # Decode only the newly generated tokens, dropping a trailing EOS token if present
    decoded_output = tokenizer.decode(
        output[
            len(tokenized.input_ids) : (
                -1 if output[-1] == tokenizer.eos_token_id else len(output)
            )
        ]
    )
    return decoded_output


demo = gr.ChatInterface(
    fn=predict,
    multimodal=True,
    title="Magistral Chat App",
    description='Chat with Magistral AI. Upload an image if relevant to your question.<br>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>',
)

if __name__ == "__main__":
    demo.launch()
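
# A minimal sketch of this Space's dependencies, inferred from the imports
# above (an assumption, not part of the original file); mistral-common is
# what backs tokenizer_type="mistral" in recent transformers releases:
#
#   # requirements.txt
#   gradio
#   spaces
#   torch
#   transformers
#   huggingface_hub
#   mistral-common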