# AperPhil / chat.py
import os
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
MODEL_ID = os.environ.get("MODEL_ID", "swiss-ai/Apertus-8B-Instruct-2509")
# ---- Load model & tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
# dtype: prefer bfloat16 on GPUs that support it (e.g. A100); fall back to
# float16 on older GPUs (e.g. T4), and to float32 on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    torch_dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # accelerate will shard across available devices
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
# Resolve an EOS id for generation; some tokenizers only define a pad token.
eos_token_id = tokenizer.eos_token_id
if eos_token_id is None:
    eos_token_id = tokenizer.pad_token_id
def _apply_chat_template_with_fallback(messages):
"""
Apply the tokenizer's chat template if present; otherwise, fall back to a simple format.
Returns a string prompt (not tokenized).
"""
try:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
except Exception:
# Fallback formatting
parts = []
for m in messages:
role = m.get("role", "user")
content = m.get("content", "")
parts.append(f"<|{role}|>\n{content}\n")
parts.append("<|assistant|>\n")
return "\n".join(parts)
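# Illustrative only: for messages = [{"role": "user", "content": "Hi there"}],
# the fallback above renders the prompt as:
#
#   <|user|>
#   Hi there
#
#   <|assistant|>
#
# The real prompt normally comes from the model's own chat template instead.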
def chat_with_model(message, history_messages, perspective):
"""
Streaming generator for Gradio (Chatbot type='messages').
Inputs:
- message: str
- history_messages: list[{'role': 'user'|'assistant', 'content': str}]
- perspective: str (system message, optional)
Yields:
- (updated_messages_for_chatbot, updated_messages_for_state)
"""
    # Compose chat messages for this turn
    chat_msgs = []
    if perspective and perspective.strip():
        chat_msgs.append({"role": "system", "content": perspective.strip()})
    # Append prior turns from UI state (already in messages format)
    for m in history_messages:
        if "role" in m and "content" in m:
            chat_msgs.append({"role": m["role"], "content": m["content"]})
    # Add the new user message
    chat_msgs.append({"role": "user", "content": message})
    # Build the prompt with the model's chat template
    prompt_text = _apply_chat_template_with_fallback(chat_msgs)
    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Set up streamer for token-wise output
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,  # silence the missing-pad-token warning during generation
    )
    # Launch generation in a background thread
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    # Start building the assistant reply incrementally
    reply = ""
    base = history_messages + [{"role": "user", "content": message}]
    # Show the user turn immediately, before the first token arrives
    yield base, base
    for token_text in streamer:
        reply += token_text
        updated = base + [{"role": "assistant", "content": reply}]
        yield updated, updated
    thread.join()
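# A minimal sketch of how this generator could be wired into a Gradio UI
# (illustrative only; the component names and layout are assumptions, not the
# actual app.py of this Space):
#
#   import gradio as gr
#
#   with gr.Blocks() as demo:
#       perspective = gr.Textbox(label="Perspective (system prompt)")
#       chatbot = gr.Chatbot(type="messages")
#       state = gr.State([])
#       msg = gr.Textbox(label="Message")
#       msg.submit(chat_with_model, [msg, state, perspective], [chatbot, state])
#
#   demo.launch()
#
# Because chat_with_model is a generator, Gradio streams each yielded
# (chatbot, state) pair to the UI as tokens arrive.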