luciagomez committed on
Commit 328eb52 · verified · 1 Parent(s): 36ea207

Create chat.py

Files changed (1)
  1. chat.py +111 -0
chat.py ADDED
@@ -0,0 +1,111 @@
import os
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = os.environ.get("MODEL_ID", "swiss-ai/Apertus-8B-Instruct-2509")

# ---- Load model & tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)

# dtype: prefer bfloat16 on GPU if supported, otherwise float16; float32 on CPU
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    torch_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # accelerate will shard across available devices
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)

# Cache the tokenizer's EOS id for generation
eos_token_id = tokenizer.eos_token_id


def _apply_chat_template_with_fallback(messages):
    """
    Apply the tokenizer's chat template if present; otherwise, fall back to a simple format.
    Returns a string prompt (not tokenized).
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Fallback formatting
        parts = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            parts.append(f"<|{role}|>\n{content}\n")
        parts.append("<|assistant|>\n")
        return "\n".join(parts)


def chat_with_model(message, history_messages, perspective):
    """
    Streaming generator for Gradio (Chatbot type='messages').
    Inputs:
      - message: str
      - history_messages: list[{'role': 'user'|'assistant', 'content': str}]
      - perspective: str (system message, optional)
    Yields:
      - (updated_messages_for_chatbot, updated_messages_for_state)
    """
    # Compose chat messages for this turn
    chat_msgs = []
    if perspective and perspective.strip():
        chat_msgs.append({"role": "system", "content": perspective.strip()})

    # Append prior turns from UI state (already in messages format)
    for m in history_messages:
        if "role" in m and "content" in m:
            chat_msgs.append({"role": m["role"], "content": m["content"]})

    # Add the new user message
    chat_msgs.append({"role": "user", "content": message})

    # Build the prompt with the model's chat template
    prompt_text = _apply_chat_template_with_fallback(chat_msgs)

    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Set up streamer for token-wise output
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=eos_token_id,
    )

    # Launch generation in a background thread
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Start building the assistant reply incrementally
    reply = ""
    base = history_messages + [{"role": "user", "content": message}]

    for token_text in streamer:
        reply += token_text
        updated = base + [{"role": "assistant", "content": reply}]
        yield updated, updated
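
Usage sketch (illustrative, not part of this commit): the docstring of chat_with_model describes a streaming generator whose yielded (chatbot_messages, state_messages) pairs are meant for a Gradio Chatbot with type='messages' plus a State component. The following minimal sketch shows one way that wiring could look; it assumes gradio is installed and chat.py is importable, and the component names (demo, chatbot_ui, state, msg_box, perspective_box) are hypothetical.

# app.py -- hypothetical Gradio wiring for chat_with_model (assumption, not in the commit)
import gradio as gr

from chat import chat_with_model

with gr.Blocks() as demo:
    perspective_box = gr.Textbox(label="Perspective (system prompt)", value="")
    chatbot_ui = gr.Chatbot(type="messages")  # expects [{'role': ..., 'content': ...}] dicts
    state = gr.State([])                      # mirrors the chat history between turns
    msg_box = gr.Textbox(label="Message")

    # chat_with_model is a generator, so Gradio streams each yielded
    # (chatbot_messages, state_messages) pair as a partial update.
    msg_box.submit(
        chat_with_model,
        inputs=[msg_box, state, perspective_box],
        outputs=[chatbot_ui, state],
    )

demo.launch()

Routing both outputs to the same yielded list keeps the visible chat and the stored history in sync after every streamed chunk.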