# edgellm/backend/services/chat_service.py
"""
Chat generation service supporting both local models and API calls
"""
import torch
from typing import Tuple
from openai import OpenAI
from .model_service import model_service
from ..config import AVAILABLE_MODELS, API_KEY, BASE_URL
class ChatService:
    """Route chat generation to an OpenAI-compatible API or a locally loaded model.

    Each generation method returns the same 4-tuple:
    ``(thinking_content, final_content, model_used, supports_thinking)``.
    """

    def __init__(self):
        # Only build an API client when an API key is configured; local models
        # remain usable even when no key is present.
        self.api_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL
        ) if API_KEY else None

    @staticmethod
    def _build_conversation(
        prompt: str,
        messages: list = None,
        system_prompt: str = None
    ) -> list:
        """Assemble the full message list sent to the model.

        Order: optional system prompt, prior conversation history (if any),
        then the current user prompt as the latest turn.
        """
        conversation = []
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        if messages:
            for msg in messages:
                conversation.append({"role": msg.get("role"), "content": msg.get("content")})
        conversation.append({"role": "user", "content": prompt})
        return conversation

    @staticmethod
    def _parse_thinking(generated_text: str, supports_thinking: bool) -> Tuple[str, str]:
        """Split raw model output into (thinking_content, final_content).

        For thinking-capable models, text between ``<thinking>`` and
        ``</thinking>`` becomes ``thinking_content`` and the remainder becomes
        ``final_content``. If the closing tag is missing, or the model does not
        support thinking, the full text is returned untouched as the final
        content with an empty thinking section.
        """
        thinking_content = ""
        final_content = generated_text
        if supports_thinking and "<thinking>" in generated_text:
            parts = generated_text.split("<thinking>")
            if len(parts) > 1:
                thinking_part = parts[1]
                if "</thinking>" in thinking_part:
                    thinking_content = thinking_part.split("</thinking>")[0].strip()
                    final_content = thinking_part.split("</thinking>", 1)[1].strip()
        return thinking_content, final_content

    def _generate_api_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a response via the OpenAI-compatible API.

        Args:
            prompt: Latest user message.
            model_name: Key into ``AVAILABLE_MODELS``.
            messages: Optional prior conversation history (role/content dicts).
            system_prompt: Optional system instruction.
            temperature: Sampling temperature forwarded to the API.
            max_new_tokens: Forwarded as ``max_tokens``.

        Returns:
            (thinking_content, final_content, model_used, supports_thinking)

        Raises:
            ValueError: If no API client is configured or the API call fails.
        """
        if not self.api_client:
            raise ValueError("API client not configured. Please check API_KEY.")

        api_messages = self._build_conversation(prompt, messages, system_prompt)
        model_info = AVAILABLE_MODELS[model_name]

        try:
            completion = self.api_client.chat.completions.create(
                model=model_name,
                messages=api_messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                stream=False
            )
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise ValueError(f"API call failed: {str(e)}") from e

        # The SDK may return None content; normalize to "" so parsing is safe.
        generated_text = completion.choices[0].message.content or ""
        thinking_content, final_content = self._parse_thinking(
            generated_text, model_info["supports_thinking"]
        )
        return (
            thinking_content,
            final_content,
            model_name,
            model_info["supports_thinking"]
        )

    def _generate_local_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a response with a locally loaded Hugging Face model.

        Args/Returns mirror :meth:`_generate_api_response`.

        Raises:
            ValueError: If the requested model has not been loaded.
        """
        if not model_service.is_model_loaded(model_name):
            raise ValueError(f"Model {model_name} is not loaded")

        model_data = model_service.models_cache[model_name]
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]
        model_info = AVAILABLE_MODELS[model_name]

        conversation = self._build_conversation(prompt, messages, system_prompt)

        # Render the conversation with the model's own chat template, leaving
        # the assistant turn open for generation.
        formatted_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Drop the prompt tokens so only newly generated text is decoded.
        generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        thinking_content, final_content = self._parse_thinking(
            generated_text, model_info["supports_thinking"]
        )
        return (
            thinking_content,
            final_content,
            model_name,
            model_info["supports_thinking"]
        )

    def generate_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a chat response, routing to the API or a local model.

        Returns:
            (thinking_content, final_content, model_used, supports_thinking)

        Raises:
            ValueError: If ``model_name`` is unknown, or the chosen backend fails.
        """
        model_info = AVAILABLE_MODELS.get(model_name)
        if not model_info:
            raise ValueError(f"Unknown model: {model_name}")

        # Route on the configured model type; anything not "api" is local.
        if model_info["type"] == "api":
            return self._generate_api_response(
                prompt, model_name, messages, system_prompt, temperature, max_new_tokens
            )
        else:
            return self._generate_local_response(
                prompt, model_name, messages, system_prompt, temperature, max_new_tokens
            )
# Module-level singleton: import `chat_service` rather than constructing
# ChatService directly, so the OpenAI client is initialized exactly once.
chat_service = ChatService()