iamthewalrus67 committed
Commit 863688d · 1 Parent(s): c5d24cb

Add chat template

Files changed (1): app.py +22 -11
app.py CHANGED
@@ -12,7 +12,10 @@ MODEL_ID = "le-llm/gemma-3-12b-it-reasoning"
 # Load model & tokenizer
 device = "cuda" if torch.cuda.is_available() else "cpu"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16 if device=="cuda" else torch.float32).to(device)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32
+).to(device)
 
 SYSTEM_PROMPT = "You are a friendly Chatbot."
 
@@ -24,14 +27,17 @@ def respond(
     temperature,
     top_p,
 ):
-
-    conversation = system_message + "\n"
-    for turn in history:
-        role = "User" if turn["role"] == "user" else "Assistant"
-        conversation += f"{role}: {turn['content']}\n"
-    conversation += f"User: {message}\nAssistant:"
-
-    inputs = tokenizer(conversation, return_tensors="pt").to(device)
+    # Build conversation in chat template format
+    messages = [{"role": "system", "content": system_message}] + history + [
+        {"role": "user", "content": message}
+    ]
+
+    input_text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True  # ensures model knows it's assistant's turn
+    )
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
 
     output_ids = model.generate(
         **inputs,
@@ -39,10 +45,15 @@ def respond(
         temperature=temperature,
         top_p=top_p,
         do_sample=True,
+        eos_token_id=tokenizer.eos_token_id,  # stop at EOS
     )
 
-    response = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
-    yield response
+    # Only return the newly generated assistant message
+    response = tokenizer.decode(
+        output_ids[0][inputs["input_ids"].shape[1]:],
+        skip_special_tokens=True
+    )
+    return response
 
 chatbot = gr.ChatInterface(
     respond,
@@ -61,4 +72,4 @@ chatbot = gr.ChatInterface(
     ],
 )
-chatbot.launch()
+chatbot.launch()
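
For anyone verifying the change locally: the new prompt construction can be exercised without launching the Space. A minimal sketch of that flow, using only the tokenizer (the exact rendered string depends on the chat template shipped with le-llm/gemma-3-12b-it-reasoning; the sample messages here are illustrative):

from transformers import AutoTokenizer

MODEL_ID = "le-llm/gemma-3-12b-it-reasoning"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Same message layout respond() now builds: system prompt,
# prior turns, then the new user message.
messages = [
    {"role": "system", "content": "You are a friendly Chatbot."},
    {"role": "user", "content": "Hello!"},
]

# Render with the model's own chat template instead of hand-rolled
# "User:/Assistant:" strings; add_generation_prompt=True appends the
# tokens that open the assistant's turn so generation starts there.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # inspect the exact string the model will be fed

Printing the rendered prompt is a quick way to confirm the template accepts the system role before generate() is involved.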