Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| import base64 | |
| import logging | |
| import os | |
| from io import BytesIO | |
| from typing import Optional | |
| import yaml | |
| from openai import AzureOpenAI, OpenAI # pip install openai | |
| from PIL import Image | |
| from tenacity import ( | |
| retry, | |
| stop_after_attempt, | |
| stop_after_delay, | |
| wait_random_exponential, | |
| ) | |
| from embodied_gen.utils.process_media import combine_images_to_grid | |
# Path of the YAML file holding the GPT agent settings. It is a relative
# path, so it is resolved against the current working directory.
CONFIG_FILE = "embodied_gen/utils/gpt_config.yaml"

# Names exported by `from <module> import *`.
__all__ = ["GPTclient"]

# Configure root logging at WARNING and cap the `httpx` logger at WARNING too.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API.

    An `AzureOpenAI` client is created when `api_version` is given;
    otherwise an `OpenAI` client is created with `endpoint` as its base URL.

    Args:
        endpoint (str): Service URL, passed as `azure_endpoint` (Azure) or
            `base_url` (OpenAI).
        api_key (str): API key handed to the underlying client.
        model_name (str): Model name sent with every request.
        api_version (Optional[str]): Azure API version. `None` selects the
            plain `OpenAI` client.
        check_connection (bool): If True, run `check_connection()` during
            construction.
        verbose (bool): If True, `query` logs the prompt and the response
            at INFO level.

    Raises:
        ConnectionError: If `check_connection` is True and the test request
            fails.
    """

    # MIME type announced in the data URL for each supported file suffix.
    _MIME_BY_SUFFIX = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".gif": "image/gif",
    }
    # Used when the real format is unknown (e.g. caller-supplied base64).
    _DEFAULT_MIME = "image/png"

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        check_connection: bool = True,
        verbose: bool = False,
    ):
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        self.model_name = model_name
        # File suffixes that make `query` treat a string input as a local
        # image path instead of base64 data.
        self.image_formats = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
        self.verbose = verbose

        if check_connection:
            self.check_connection()

        logger.info("Using GPT model: %s.", self.model_name)

    def completion_with_backoff(self, **kwargs):
        """Forward `kwargs` to the client's chat-completions `create` call.

        NOTE(review): despite the name, no retry/backoff is applied here as
        written; confirm whether a retry decorator is expected on this
        method.
        """
        return self.client.chat.completions.create(**kwargs)

    def _encode_image(self, img) -> tuple[str, str]:
        """Convert one image input into a `(mime_type, payload)` pair.

        Args:
            img: A local image path, a string of already-encoded base64
                data, or a `PIL.Image.Image`.

        Returns:
            tuple[str, str]: MIME type and payload for the data URL.

        Raises:
            FileNotFoundError: If `img` has an image suffix but no such
                file exists.
            TypeError: If `img` is none of the supported input types.
        """
        if isinstance(img, (str, bytes, os.PathLike)):
            suffix = os.path.splitext(img)[-1].lower()
            if suffix not in self.image_formats:
                # No known image suffix: the value is passed through as-is
                # and assumed to be base64 data of unknown format.
                return self._DEFAULT_MIME, img
            if not os.path.exists(img):
                raise FileNotFoundError(f"Image file not found: {img}")
            with open(img, "rb") as f:
                encoded = base64.b64encode(f.read()).decode("utf-8")
            return self._MIME_BY_SUFFIX.get(suffix, self._DEFAULT_MIME), encoded

        if isinstance(img, Image.Image):
            fmt = img.format or "PNG"
            buffer = BytesIO()
            img.save(buffer, format=fmt)
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            # Report the format the bytes were actually saved in; looked up
            # after `save` so the format plugin has been loaded.
            mime = Image.MIME.get(fmt.upper(), f"image/{fmt.lower()}")
            return mime, encoded

        raise TypeError(
            f"Unsupported image input type: {type(img).__name__}"
        )

    def query(
        self,
        text_prompt: str,
        # Kept as a string so the annotation is not evaluated at definition
        # time (`X | Y` between types needs Python 3.10+ when evaluated).
        image_base64: "Optional[list[str | Image.Image]]" = None,
        system_role: Optional[str] = None,
        params: Optional[dict] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.

        Args:
            text_prompt (str): The main text input that the model responds to.
            image_base64 (Optional[List[str]]): A list of image base64 strings
                or local image paths or PIL.Image to accompany the text
                prompt. A single item (not wrapped in a list) is accepted too.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.
            params (Optional[dict]): Additional parameters for GPT setting;
                they override the default request payload entries.

        Returns:
            Optional[str]: The response content generated by the model based on
                the prompt. Returns `None` if the API call fails.

        Raises:
            FileNotFoundError: If an image path does not exist (raised before
                the request is sent).
        """
        if system_role is None:
            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]

        # Process images if provided
        if image_base64 is not None:
            if not isinstance(image_base64, list):
                image_base64 = [image_base64]

            # Temporary hard-coded workaround: OpenRouter cannot take multiple
            # images, so they are combined by `combine_images_to_grid`.
            # The endpoint may be None, hence the truthiness guard.
            if self.endpoint and "openrouter" in self.endpoint:
                image_base64 = combine_images_to_grid(image_base64)

            for img in image_base64:
                mime, data = self._encode_image(img)
                content_user.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:{mime};base64,{data}"},
                    }
                )

        payload = {
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
            "model": self.model_name,
        }
        if params:
            payload.update(params)

        try:
            completion = self.completion_with_backoff(**payload)
            response = completion.choices[0].message.content
        except Exception as e:
            # Best-effort API: failures are logged and reported as None.
            logger.error(
                "Error GPTclient %s API call: %s", self.endpoint, e
            )
            response = None

        if self.verbose:
            logger.info(f"Prompt: {text_prompt}")
            logger.info(f"Response: {response}")

        return response

    def check_connection(self) -> None:
        """Check whether the GPT API connection is working.

        Sends a minimal chat request and reads the first choice of the reply.

        Raises:
            ConnectionError: If the request or reading the reply fails; the
                underlying exception is attached as `__cause__`.
        """
        try:
            response = self.completion_with_backoff(
                messages=[
                    {"role": "system", "content": "You are a test system."},
                    {"role": "user", "content": "Hello"},
                ],
                model=self.model_name,
                temperature=0,
                max_tokens=100,
            )
            # Accessed only to verify the reply has the expected shape.
            _ = response.choices[0].message.content
        except Exception as e:
            raise ConnectionError(
                f"Failed to connect to GPT API at {self.endpoint}, "
                f"please check setting in `{CONFIG_FILE}` and `README`."
            ) from e

        logger.info("Connection check success.")
# Build the shared module-level client from the YAML config, letting
# environment variables override individual settings.
with open(CONFIG_FILE, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

agent_type = config["agent_type"]
# `or {}` also covers a section that exists but is empty (parsed as None),
# which would otherwise break the `.get` lookups below.
agent_config = config.get(agent_type) or {}

# Prefer environment variables, fallback to YAML config
endpoint = os.environ.get("ENDPOINT", agent_config.get("endpoint"))
api_key = os.environ.get("API_KEY", agent_config.get("api_key"))
api_version = os.environ.get("API_VERSION", agent_config.get("api_version"))
model_name = os.environ.get("MODEL_NAME", agent_config.get("model_name"))

# The connection check is skipped so that importing this module does not
# issue an API request.
GPT_CLIENT = GPTclient(
    endpoint=endpoint,
    api_key=api_key,
    api_version=api_version,
    model_name=model_name,
    check_connection=False,
)