import os
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = os.environ.get("MODEL_ID", "swiss-ai/Apertus-8B-Instruct-2509")

# ---- Load model & tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)

# dtype: prefer bfloat16 on GPUs that support it (e.g. A100); fall back to float16 on older GPUs (e.g. T4), and float32 on CPU
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    torch_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",           # accelerate will shard across available devices
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)

# Cache the tokenizer's EOS token id for generation
eos_token_id = tokenizer.eos_token_id


def _apply_chat_template_with_fallback(messages):
    """
    Apply the tokenizer's chat template if present; otherwise, fall back to a simple format.
    Returns a string prompt (not tokenized).
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Fallback formatting
        parts = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            parts.append(f"<|{role}|>\n{content}\n")
        parts.append("<|assistant|>\n")
        return "\n".join(parts)
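
# Illustrative example of the fallback format (used only when the tokenizer has
# no chat template): for messages = [{"role": "user", "content": "Hi"}], the
# resulting prompt string is:
#
#   <|user|>
#   Hi
#
#   <|assistant|>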


def chat_with_model(message, history_messages, perspective):
    """
    Streaming generator for Gradio (Chatbot type='messages').
    Inputs:
      - message: str
      - history_messages: list[{'role': 'user'|'assistant', 'content': str}]
      - perspective: str (system message, optional)
    Yields:
      - (updated_messages_for_chatbot, updated_messages_for_state)
    """
    # Compose chat messages for this turn
    chat_msgs = []
    if perspective and perspective.strip():
        chat_msgs.append({"role": "system", "content": perspective.strip()})

    # Append prior turns from UI state (already in messages format)
    for m in history_messages:
        if "role" in m and "content" in m:
            chat_msgs.append({"role": m["role"], "content": m["content"]})

    # Add the new user message
    chat_msgs.append({"role": "user", "content": message})

    # Build the prompt with the model's chat template
    prompt_text = _apply_chat_template_with_fallback(chat_msgs)

    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Set up streamer for token-wise output
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=eos_token_id,
    )

    # Launch generation in a background thread
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Start building the assistant reply incrementally
    reply = ""
    base = history_messages + [{"role": "user", "content": message}]

    for token_text in streamer:
        reply += token_text
        updated = base + [{"role": "assistant", "content": reply}]
        yield updated, updated

    # Wait for the generation thread to finish, then emit the final state once more
    # (this also covers the edge case where the model produced no tokens at all)
    thread.join()
    final = base + [{"role": "assistant", "content": reply}]
    yield final, final
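

# The block below is an illustrative sketch of how this streaming generator
# could be wired into a Gradio Blocks UI, assuming gradio >= 4 with
# Chatbot(type="messages"). Component names and layout are placeholders,
# not taken from the original app.
if __name__ == "__main__":
    import gradio as gr

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(type="messages", label="Apertus chat")
        history = gr.State([])  # running message history in 'messages' format
        perspective = gr.Textbox(label="Perspective / system prompt (optional)")
        msg = gr.Textbox(label="Your message")

        # Each (chatbot_messages, state_messages) pair yielded by chat_with_model
        # updates both the visible chat and the stored history as tokens stream in
        msg.submit(
            chat_with_model,
            inputs=[msg, history, perspective],
            outputs=[chatbot, history],
        )

    demo.queue().launch()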