import os
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = os.environ.get("MODEL_ID", "swiss-ai/Apertus-8B-Instruct-2509")

# ---- Load model & tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
# dtype: prefer bfloat16 on GPUs that support it (e.g. A100), fall back to float16
# on older GPUs (e.g. T4), and use float32 on CPU
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    torch_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # accelerate will shard across available devices
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
# Cache the tokenizer's EOS token id for generation (may be None for some tokenizers)
eos_token_id = tokenizer.eos_token_id

def _apply_chat_template_with_fallback(messages):
    """
    Apply the tokenizer's chat template if present; otherwise, fall back to a simple format.
    Returns a string prompt (not tokenized).
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Fallback formatting
        parts = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            parts.append(f"<|{role}|>\n{content}\n")
        parts.append("<|assistant|>\n")
        return "\n".join(parts)

def chat_with_model(message, history_messages, perspective):
    """
    Streaming generator for Gradio (Chatbot type='messages').
    Inputs:
        - message: str
        - history_messages: list[{'role': 'user'|'assistant', 'content': str}]
        - perspective: str (system message, optional)
    Yields:
        - (updated_messages_for_chatbot, updated_messages_for_state)
    """
    # Compose chat messages for this turn
    chat_msgs = []
    if perspective and perspective.strip():
        chat_msgs.append({"role": "system", "content": perspective.strip()})
    # Append prior turns from UI state (already in messages format)
    for m in history_messages:
        if "role" in m and "content" in m:
            chat_msgs.append({"role": m["role"], "content": m["content"]})
    # Add the new user message
    chat_msgs.append({"role": "user", "content": message})

    # Build the prompt with the model's chat template
    prompt_text = _apply_chat_template_with_fallback(chat_msgs)
    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Set up streamer for token-wise output
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=eos_token_id,
    )

    # Launch generation in a background thread
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Build the assistant reply incrementally as tokens stream in
    reply = ""
    base = history_messages + [{"role": "user", "content": message}]
    for token_text in streamer:
        reply += token_text
        updated = base + [{"role": "assistant", "content": reply}]
        yield updated, updated
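
# --- Hypothetical UI wiring (a sketch, not part of the original file): the generator
# --- above follows the Gradio Chatbot(type="messages") streaming contract, so it could
# --- be hooked up roughly like this, assuming the `gradio` package is installed.
# --- Component names and labels here are illustrative assumptions.
import gradio as gr

with gr.Blocks() as demo:
    perspective_box = gr.Textbox(label="Perspective / system prompt (optional)")
    chatbot = gr.Chatbot(type="messages")
    history_state = gr.State([])  # conversation history in messages format
    msg_box = gr.Textbox(label="Message")

    # chat_with_model is a generator, so Gradio streams each yielded
    # (chatbot_messages, state_messages) pair to the UI as tokens arrive.
    msg_box.submit(
        chat_with_model,
        inputs=[msg_box, history_state, perspective_box],
        outputs=[chatbot, history_state],
    )

if __name__ == "__main__":
    demo.launch()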