import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import Mistral3ForConditionalGeneration, AutoTokenizer
from typing import Any
import base64
import mimetypes
from pathlib import Path

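# The repo's SYSTEM_PROMPT.txt embeds a [THINK]...[/THINK] example block;
# strip it and keep only the surrounding instructions as the system message.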
def load_system_prompt(repo_id: str, filename: str) -> dict[str, Any]:
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(file_path, "r") as file:
        system_prompt = file.read()

    index_begin_think = system_prompt.find("[THINK]")
    index_end_think = system_prompt.find("[/THINK]")

    return {
        "role": "system",
        "content": [
            {"type": "text", "text": system_prompt[:index_begin_think]},
            {
                "type": "text",
                "text": system_prompt[index_end_think + len("[/THINK]") :],
            },
        ],
    }

model_id = "mistralai/Magistral-Small-2509"
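# tokenizer_type="mistral" loads the mistral-common tokenizer, which applies
# the model's chat template to mixed text/image_url content parts.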
tokenizer = AutoTokenizer.from_pretrained(model_id, tokenizer_type="mistral")
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()


SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt")

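# On a ZeroGPU Space, @spaces.GPU attaches a GPU to each call for up to
# `duration` seconds; the model weights are loaded once at startup.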
@spaces.GPU(duration=120)
def predict(message: dict, history: list) -> str:
    # Rebuild the model-side conversation from Gradio's (user, assistant)
    # tuple history. File-only user turns arrive as tuples rather than
    # strings, so only string messages are replayed as text.
    messages = [SYSTEM_PROMPT]
    for user_msg, assistant_msg in history:
        if isinstance(user_msg, str):
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

    # Process current user message (with potential image)
    user_content = [{"type": "text", "text": message["text"]}]
    if message["files"]:
        # Assuming one image file from multimodal textbox
        image_path = Path(message['files'][0])
        image_bytes = image_path.read_bytes()
        encoded_image = base64.b64encode(image_bytes).decode("utf-8")
        mime_type, _ = mimetypes.guess_type(image_path)
        if mime_type is None:
            mime_type = "image/png"
        data_url = f"data:{mime_type};base64,{encoded_image}"
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})

    messages.append({"role": "user", "content": user_content})

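    # apply_chat_template with the mistral tokenizer returns plain Python
    # lists (plus pixel data for image turns), so tensors are built by hand.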
    tokenized = tokenizer.apply_chat_template(messages, return_dict=True)

    # Keep inputs on the model's device rather than hard-coding "cuda".
    device = model.device
    input_ids = torch.tensor(tokenized.input_ids, device=device).unsqueeze(0)
    attention_mask = torch.tensor(tokenized.attention_mask, device=device).unsqueeze(0)

    if "pixel_values" in tokenized and len(tokenized.pixel_values) > 0:
        # An image was attached: pass pixel_values and image_sizes so the
        # vision encoder runs alongside the text stream.
        pixel_values = torch.tensor(
            tokenized.pixel_values[0], dtype=torch.bfloat16, device=device
        ).unsqueeze(0)
        image_sizes = torch.tensor(pixel_values.shape[-2:], device=device).unsqueeze(0)
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            max_new_tokens=2048,  # explicit cap; don't rely on the default generation config
        )[0]
    else:
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=2048,
        )[0]

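    # Decode only the newly generated tokens, trimming the trailing EOS token
    # if one was produced.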
    decoded_output = tokenizer.decode(
        output[
            len(tokenized.input_ids) : (
                -1 if output[-1] == tokenizer.eos_token_id else len(output)
            )
        ]
    )
    return decoded_output

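# With multimodal=True, Gradio passes predict() a dict holding "text" and
# "files" instead of a plain string.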
demo = gr.ChatInterface(
    fn=predict,
    multimodal=True,
    title="Magistral Chat App",
    description='Chat with Magistral AI. Upload an image if relevant to your question.<br>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>',
)

if __name__ == "__main__":
    demo.launch()