Spaces:
Runtime error
Runtime error
| from typing import Optional, Union | |
| import weave | |
| from openai import OpenAI | |
| from openai.types.chat import ChatCompletion | |
| class OpenAIModel(weave.Model): | |
| model_name: str | |
| _openai_client: OpenAI | |
| def __init__(self, model_name: str = "gpt-4o") -> None: | |
| super().__init__(model_name=model_name) | |
| self._openai_client = OpenAI() | |
| def create_messages( | |
| self, | |
| user_prompts: Union[str, list[str]], | |
| system_prompt: Optional[str] = None, | |
| messages: Optional[list[dict]] = None, | |
| ) -> list[dict]: | |
| user_prompts = [user_prompts] if isinstance(user_prompts, str) else user_prompts | |
| messages = list(messages) if isinstance(messages, dict) else [] | |
| for user_prompt in user_prompts: | |
| messages.append({"role": "user", "content": user_prompt}) | |
| if system_prompt is not None: | |
| messages = [{"role": "system", "content": system_prompt}] + messages | |
| return messages | |
| def predict( | |
| self, | |
| user_prompts: Union[str, list[str]], | |
| system_prompt: Optional[str] = None, | |
| messages: Optional[list[dict]] = None, | |
| **kwargs, | |
| ) -> ChatCompletion: | |
| messages = self.create_messages(user_prompts, system_prompt, messages) | |
| if "response_format" in kwargs: | |
| response = self._openai_client.beta.chat.completions.parse( | |
| model=self.model_name, messages=messages, **kwargs | |
| ) | |
| else: | |
| response = self._openai_client.chat.completions.create( | |
| model=self.model_name, messages=messages, **kwargs | |
| ) | |
| return response | |