#!/usr/bin/env python3
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer (same as server.py)
model_name = "models/Llama-3.2-1B-Instruct"
tok = AutoTokenizer.from_pretrained(model_name)
# NOTE: assumed loading config; server.py may use different dtype/device settings.
lm = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
lm.to("cuda" if torch.cuda.is_available() else "cpu")
def chat_current(system_prompt: str, user_prompt: str) -> str:
    """
    Current implementation (same as server.py) - will show warnings
    """
    print("🔴 Running CURRENT implementation (with warnings)...")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(lm.device)
    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids,  # No attention_mask, no pad_token_id
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
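
# Because no attention_mask or pad_token_id is passed, transformers typically logs a
# warning along the lines of "The attention mask and the pad token id were not set"
# and falls back to using eos_token_id as the pad token, which can make generations
# less reliable. chat_fixed below addresses both points.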
com_add = "5F71XTGBnBGzxiPxCK4EbWMnhckH21tGWSRfe6NrMdxMe6kg"
def chat_fixed(system_prompt: str, user_prompt: str) -> str:
    """
    Fixed implementation - proper attention mask and pad token
    """
    print("🟢 Running FIXED implementation (no warnings)...")
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # Get both input_ids and attention_mask
    inputs = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,  # Returns dict with input_ids and attention_mask
    )
    # Move to device
    input_ids = inputs["input_ids"].to(lm.device)
    attention_mask = inputs["attention_mask"].to(lm.device)
    with torch.inference_mode():
        output_ids = lm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,  # Proper attention mask
            pad_token_id=tok.eos_token_id,  # Explicit pad token
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.2,
            repetition_penalty=1.1,
            top_k=100,
            top_p=0.95,
        )
    answer = tok.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()
def compare_generations():
    """Compare both implementations."""
    system_prompt = "You are a helpful assistant who tries to help answer the user's question."
    user_prompt = "Create a report on anxiety at work. How do I manage time and stress effectively?"
    print("=" * 60)
    print("COMPARING GENERATION METHODS")
    print("=" * 60)
    print(f"System: {system_prompt}")
    print(f"User: {user_prompt}")
    print("=" * 60)
    # Test current implementation
    print("\n" + "=" * 60)
    current_output = chat_current(system_prompt, user_prompt)
    print(f"CURRENT OUTPUT:\n{current_output}")
    print("\n" + "=" * 60)
    # Test fixed implementation
    fixed_output = chat_fixed(system_prompt, user_prompt)
    print(f"FIXED OUTPUT:\n{fixed_output}")
    print("\n" + "=" * 60)
    print("COMPARISON:")
    print(f"Outputs are identical: {current_output == fixed_output}")
    print(f"Current length: {len(current_output)} chars")
    print(f"Fixed length: {len(fixed_output)} chars")
# if __name__ == "__main__":
#     # Set pad token for the fixed version
#     if tok.pad_token is None:
#         tok.pad_token = tok.eos_token
#     compare_generations()
def filter_by_word_count(data, max_words=3):
    """Return only phrases with word count <= max_words."""
    return {k: v for k, v in data.items() if len(v.split()) <= max_words}


def filter_by_keyword(data, keyword):
    """Return phrases containing a specific keyword."""
    return {k: v for k, v in data.items() if keyword.lower() in v.lower()}
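
# Minimal usage sketch for the filter helpers above; `sample_phrases` is a
# hypothetical example dict, not data used by server.py.
def _demo_filters():
    sample_phrases = {
        "q1": "time management",
        "q2": "managing anxiety at work",
        "q3": "stress relief tips",
    }
    print(filter_by_word_count(sample_phrases, max_words=2))  # {'q1': 'time management'}
    print(filter_by_keyword(sample_phrases, "stress"))        # {'q3': 'stress relief tips'}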
example_prompt = "Score this as an answer of 5 points on a scale from 5 to 10. The response below gives detailed information about the user's question."