"""
Chat generation service supporting both local models and API calls
"""
import torch
from typing import Optional, Tuple
from openai import OpenAI
from .model_service import model_service
from ..config import AVAILABLE_MODELS, API_KEY, BASE_URL


class ChatService:
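    """Routes chat generation either to an OpenAI-compatible API or to a
    locally loaded model, based on each model's entry in AVAILABLE_MODELS."""
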
def __init__(self):
# Initialize OpenAI client for API calls
self.api_client = OpenAI(
api_key=API_KEY,
base_url=BASE_URL
) if API_KEY else None
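        # If API_KEY is unset, the client stays None and API-backed models
        # will raise in _generate_api_response.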

    def _generate_api_response(
        self,
        prompt: str,
        model_name: str,
        messages: Optional[list] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
"""Generate response using API"""
if not self.api_client:
raise ValueError("API client not configured. Please check API_KEY.")
# Build messages with conversation history
api_messages = []
if system_prompt:
api_messages.append({"role": "system", "content": system_prompt})
# Add conversation history
if messages:
for msg in messages:
api_messages.append({"role": msg.get("role"), "content": msg.get("content")})
# Add current prompt as the latest user message
api_messages.append({"role": "user", "content": prompt})
model_info = AVAILABLE_MODELS[model_name]
try:
# Make API call
completion = self.api_client.chat.completions.create(
model=model_name,
messages=api_messages,
temperature=temperature,
max_tokens=max_new_tokens,
stream=False
)
generated_text = completion.choices[0].message.content
# Parse thinking vs final content for thinking models
thinking_content = ""
final_content = generated_text
if model_info["supports_thinking"] and "" in generated_text:
parts = generated_text.split("")
if len(parts) > 1:
thinking_part = parts[1]
if "" in thinking_part:
thinking_content = thinking_part.split("")[0].strip()
remaining = thinking_part.split("", 1)[1] if "" in thinking_part else ""
final_content = remaining.strip()
return (
thinking_content,
final_content,
model_name,
model_info["supports_thinking"]
)
        except Exception as e:
            raise ValueError(f"API call failed: {e}") from e

    def _generate_local_response(
        self,
        prompt: str,
        model_name: str,
        messages: Optional[list] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
"""Generate response using local model"""
if not model_service.is_model_loaded(model_name):
raise ValueError(f"Model {model_name} is not loaded")
# Get model and tokenizer
model_data = model_service.models_cache[model_name]
model = model_data["model"]
tokenizer = model_data["tokenizer"]
model_info = AVAILABLE_MODELS[model_name]
# Build the conversation with full history
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
# Add conversation history
if messages:
for msg in messages:
conversation.append({"role": msg.get("role"), "content": msg.get("content")})
# Add current prompt as the latest user message
conversation.append({"role": "user", "content": prompt})
# Apply chat template
formatted_prompt = tokenizer.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True
)
# Tokenize
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
# Generate
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
        # Decode only the newly generated tokens (skip the prompt portion)
generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
# Parse thinking vs final content for thinking models
thinking_content = ""
final_content = generated_text
if model_info["supports_thinking"] and "" in generated_text:
parts = generated_text.split("")
if len(parts) > 1:
thinking_part = parts[1]
if "" in thinking_part:
thinking_content = thinking_part.split("")[0].strip()
remaining = thinking_part.split("", 1)[1] if "" in thinking_part else ""
final_content = remaining.strip()
return (
thinking_content,
final_content,
model_name,
model_info["supports_thinking"]
)

    def generate_response(
        self,
        prompt: str,
        model_name: str,
        messages: Optional[list] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
"""
Generate chat response using appropriate method (API or local)
Returns: (thinking_content, final_content, model_used, supports_thinking)
"""
model_info = AVAILABLE_MODELS.get(model_name)
if not model_info:
raise ValueError(f"Unknown model: {model_name}")
# Route to appropriate generation method
if model_info["type"] == "api":
return self._generate_api_response(
prompt, model_name, messages, system_prompt, temperature, max_new_tokens
)
else:
return self._generate_local_response(
prompt, model_name, messages, system_prompt, temperature, max_new_tokens
)


# Global chat service instance
chat_service = ChatService()
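

# Minimal usage sketch (illustrative only): the model key below is hypothetical;
# substitute a real key from AVAILABLE_MODELS in config. A "local" model must
# already be loaded via model_service, and an "api" model requires API_KEY.
if __name__ == "__main__":
    thinking, answer, used_model, supports_thinking = chat_service.generate_response(
        prompt="Hello! What can you do?",
        model_name="example-model",  # hypothetical key; replace with an AVAILABLE_MODELS entry
        system_prompt="You are a helpful assistant.",
    )
    if supports_thinking and thinking:
        print(f"[{used_model}] thinking:\n{thinking}\n")
    print(f"[{used_model}] answer:\n{answer}")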