|
|
""" |
|
|
Chat generation service supporting both local models and API calls |
|
|
""" |
|
|
import torch |
|
|
from typing import Tuple |
|
|
from openai import OpenAI |
|
|
from .model_service import model_service |
|
|
from ..config import AVAILABLE_MODELS, API_KEY, BASE_URL |
|
|
|
|
|
|
|
|
class ChatService:
    """Chat generation service that routes a request either to an
    OpenAI-compatible remote API or to a locally loaded model, based on
    the selected model's entry in AVAILABLE_MODELS.

    All generation methods return the same 4-tuple:
    (thinking_content, final_content, model_used, supports_thinking).
    """

    def __init__(self):
        # Only build the API client when a key is configured; local-only
        # deployments leave API_KEY unset and api_client as None, which
        # _generate_api_response checks before any remote call.
        self.api_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL
        ) if API_KEY else None

    @staticmethod
    def _build_messages(prompt, messages=None, system_prompt=None):
        """Assemble an OpenAI-style message list shared by both backends.

        Order: optional system prompt, then prior conversation turns
        (each a dict with "role"/"content" keys), then the new user prompt.
        Returns a list of {"role": ..., "content": ...} dicts.
        """
        chat = []
        if system_prompt:
            chat.append({"role": "system", "content": system_prompt})
        if messages:
            for msg in messages:
                chat.append({"role": msg.get("role"), "content": msg.get("content")})
        chat.append({"role": "user", "content": prompt})
        return chat

    @staticmethod
    def _parse_thinking(generated_text, supports_thinking):
        """Split an optional <thinking>...</thinking> section out of text.

        Returns (thinking_content, final_content). When the model does not
        support thinking, or no complete tag pair is present, the thinking
        part is "" and final_content is the full text. Only the first tag
        pair is extracted.
        """
        thinking_content = ""
        final_content = generated_text
        if supports_thinking and "<thinking>" in generated_text:
            _, _, after_open = generated_text.partition("<thinking>")
            if "</thinking>" in after_open:
                inner, _, remaining = after_open.partition("</thinking>")
                thinking_content = inner.strip()
                final_content = remaining.strip()
        return thinking_content, final_content

    def _generate_api_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a response using the remote API backend.

        Raises ValueError if no API client is configured or the API call
        fails (the original exception is chained as the cause).
        Returns (thinking_content, final_content, model_name, supports_thinking).
        """
        if not self.api_client:
            raise ValueError("API client not configured. Please check API_KEY.")

        api_messages = self._build_messages(prompt, messages, system_prompt)
        model_info = AVAILABLE_MODELS[model_name]

        try:
            completion = self.api_client.chat.completions.create(
                model=model_name,
                messages=api_messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                stream=False
            )

            generated_text = completion.choices[0].message.content
            thinking_content, final_content = self._parse_thinking(
                generated_text, model_info["supports_thinking"]
            )

            return (
                thinking_content,
                final_content,
                model_name,
                model_info["supports_thinking"]
            )

        except Exception as e:
            # Chain the cause so the underlying API error is not lost.
            raise ValueError(f"API call failed: {str(e)}") from e

    def _generate_local_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a response using a locally loaded model.

        Raises ValueError if the model is not loaded in model_service.
        Returns (thinking_content, final_content, model_name, supports_thinking).
        """
        if not model_service.is_model_loaded(model_name):
            raise ValueError(f"Model {model_name} is not loaded")

        model_data = model_service.models_cache[model_name]
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]
        model_info = AVAILABLE_MODELS[model_name]

        conversation = self._build_messages(prompt, messages, system_prompt)

        # Render the conversation with the model's chat template, leaving
        # the assistant turn open for generation.
        formatted_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        thinking_content, final_content = self._parse_thinking(
            generated_text, model_info["supports_thinking"]
        )

        return (
            thinking_content,
            final_content,
            model_name,
            model_info["supports_thinking"]
        )

    def generate_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """
        Generate chat response using appropriate method (API or local)

        Dispatches on AVAILABLE_MODELS[model_name]["type"]: "api" goes to
        the remote backend, anything else to the local backend.

        Raises ValueError for an unknown model name.
        Returns: (thinking_content, final_content, model_used, supports_thinking)
        """
        model_info = AVAILABLE_MODELS.get(model_name)
        if not model_info:
            raise ValueError(f"Unknown model: {model_name}")

        if model_info["type"] == "api":
            return self._generate_api_response(
                prompt, model_name, messages, system_prompt, temperature, max_new_tokens
            )
        else:
            return self._generate_local_response(
                prompt, model_name, messages, system_prompt, temperature, max_new_tokens
            )
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton shared by importers of this module; constructed at
# import time (builds the OpenAI client only when API_KEY is configured).
chat_service = ChatService()
|
|
|