File size: 6,725 Bytes
d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b 6a50e97 d8e039b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
"""
Chat generation service supporting both local models and API calls
"""
import torch
from typing import Tuple
from openai import OpenAI
from .model_service import model_service
from ..config import AVAILABLE_MODELS, API_KEY, BASE_URL
class ChatService:
    """Chat generation service that routes a request either to a remote
    OpenAI-compatible API or to a locally loaded model (via ``model_service``).

    Both backends share the same message assembly and ``<thinking>``-tag
    parsing, factored into the static helpers ``_build_messages`` and
    ``_parse_thinking`` so the two generation paths cannot drift apart.
    """

    def __init__(self):
        # Only construct the OpenAI client when an API key is configured,
        # so local-only deployments need no credentials. Remains None
        # otherwise; _generate_api_response guards against that.
        self.api_client = OpenAI(
            api_key=API_KEY,
            base_url=BASE_URL
        ) if API_KEY else None

    @staticmethod
    def _build_messages(prompt: str, messages: list = None, system_prompt: str = None) -> list:
        """Assemble the chat message list in the order both backends expect.

        Order: optional system prompt first, then the prior conversation
        history (``role``/``content`` dicts), then *prompt* as the final
        user turn.
        """
        chat = []
        if system_prompt:
            chat.append({"role": "system", "content": system_prompt})
        if messages:
            for msg in messages:
                chat.append({"role": msg.get("role"), "content": msg.get("content")})
        chat.append({"role": "user", "content": prompt})
        return chat

    @staticmethod
    def _parse_thinking(generated_text: str, supports_thinking: bool) -> Tuple[str, str]:
        """Split raw model output into ``(thinking_content, final_content)``.

        For thinking-capable models, the text inside the first complete
        ``<thinking>...</thinking>`` block is returned separately and the
        text after the closing tag becomes the final answer. If the model
        does not support thinking, or no complete tag pair is present
        (including an unclosed ``<thinking>``), the raw text is returned
        unchanged as the final content.
        """
        thinking_content = ""
        final_content = generated_text
        if supports_thinking and "<thinking>" in generated_text:
            _, _, after_open = generated_text.partition("<thinking>")
            if "</thinking>" in after_open:
                # partition gives us both halves in one pass; the redundant
                # double-split/ternary of the original is gone.
                inner, _, tail = after_open.partition("</thinking>")
                thinking_content = inner.strip()
                final_content = tail.strip()
        return thinking_content, final_content

    def _generate_api_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a response through the OpenAI-compatible API.

        Returns ``(thinking_content, final_content, model_name,
        supports_thinking)``. Raises ValueError if no API client is
        configured or the API call itself fails.
        """
        if not self.api_client:
            raise ValueError("API client not configured. Please check API_KEY.")

        api_messages = self._build_messages(prompt, messages, system_prompt)
        model_info = AVAILABLE_MODELS[model_name]

        # Keep the try narrow: only the network call should be reported as
        # "API call failed" — parsing errors would otherwise be mislabeled.
        try:
            completion = self.api_client.chat.completions.create(
                model=model_name,
                messages=api_messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                stream=False
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"API call failed: {str(e)}") from e

        generated_text = completion.choices[0].message.content
        thinking_content, final_content = self._parse_thinking(
            generated_text, model_info["supports_thinking"]
        )
        return (
            thinking_content,
            final_content,
            model_name,
            model_info["supports_thinking"]
        )

    def _generate_local_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """Generate a response with a locally loaded model.

        Returns ``(thinking_content, final_content, model_name,
        supports_thinking)``. Raises ValueError if the model has not been
        loaded by ``model_service``.
        """
        if not model_service.is_model_loaded(model_name):
            raise ValueError(f"Model {model_name} is not loaded")

        model_data = model_service.models_cache[model_name]
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]
        model_info = AVAILABLE_MODELS[model_name]

        conversation = self._build_messages(prompt, messages, system_prompt)

        # Render the conversation with the model's own chat template and
        # append the generation prompt so the model continues as assistant.
        formatted_prompt = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt echo.
        generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        thinking_content, final_content = self._parse_thinking(
            generated_text, model_info["supports_thinking"]
        )
        return (
            thinking_content,
            final_content,
            model_name,
            model_info["supports_thinking"]
        )

    def generate_response(
        self,
        prompt: str,
        model_name: str,
        messages: list = None,
        system_prompt: str = None,
        temperature: float = 0.7,
        max_new_tokens: int = 1024
    ) -> Tuple[str, str, str, bool]:
        """
        Generate chat response using appropriate method (API or local)
        Returns: (thinking_content, final_content, model_used, supports_thinking)

        Raises ValueError for an unknown model name, or when the chosen
        backend cannot serve the request (see the private methods).
        """
        model_info = AVAILABLE_MODELS.get(model_name)
        if not model_info:
            raise ValueError(f"Unknown model: {model_name}")

        # Route to the backend declared in the model's config entry.
        if model_info["type"] == "api":
            return self._generate_api_response(
                prompt, model_name, messages, system_prompt, temperature, max_new_tokens
            )
        else:
            return self._generate_local_response(
                prompt, model_name, messages, system_prompt, temperature, max_new_tokens
            )
# Global chat service instance.
# NOTE: constructed at import time, so importing this module builds an
# OpenAI client whenever API_KEY is set (see ChatService.__init__).
chat_service = ChatService()
|