import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
lm = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
).eval()
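
# Note: device_map="cuda" assumes a CUDA-capable GPU is available. A CPU-only
# fallback (a rough sketch, assuming the weights fit in system RAM) could be:
# lm = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).eval()
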
def chat_current(system_prompt: str, user_prompt: str) -> str:
    """
    Current implementation (same as server.py). generate() receives only the
    input_ids, so transformers warns about the missing attention mask and the
    unset pad token.
    """
    print("🔴 Running CURRENT implementation (with warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Returns a bare tensor of token ids; no attention mask is produced here.
    input_ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()


def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    """
    Fixed implementation: passes an explicit attention mask and pad_token_id
    so generate() runs without warnings.
    """
    print("🟢 Running FIXED implementation (no warnings)...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # return_dict=True yields both input_ids and the matching attention_mask.
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)

    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=tok.eos_token_id,
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )

    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()


def compare_generations():
    """Compare both implementations on the same prompts."""
    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    user_prompt = "Create a report on anxiety at work. How do I manage time and stress effectively?"

    print("=" * 60)
    print("COMPARING GENERATION METHODS")
    print("=" * 60)
    print(f"System: {system_prompt}")
    print(f"User: {user_prompt}")
    print("=" * 60)

    print("\n" + "=" * 60)
    current_output = chat_current(system_prompt, user_prompt)
    print(f"CURRENT OUTPUT:\n{current_output}")

    print("\n" + "=" * 60)
    fixed_output = chat_fixed(system_prompt, user_prompt)
    print(f"FIXED OUTPUT:\n{fixed_output}")

    print("\n" + "=" * 60)
    print("COMPARISON:")
    # With do_sample=True the two outputs are expected to differ; this check
    # mainly confirms that both code paths produce a usable response.
    print(f"Outputs are identical: {current_output == fixed_output}")
    print(f"Current length: {len(current_output)} chars")
    print(f"Fixed length: {len(fixed_output)} chars")


if __name__ == "__main__":
    # Make sure the tokenizer has a pad token before generating.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
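
    # Optionally seed the RNG so a full run is reproducible across invocations
    # (the seed value is only an example):
    # torch.manual_seed(0)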

    compare_generations()