# HF Spaces page header (scrape artifact): "Spaces: Running on T4"
| from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT | |
| from prompt_examples import AYA_VISION_PROMPT_EXAMPLES | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import logging | |
| import cohere | |
| import os | |
| import traceback | |
| import random | |
| import gradio as gr | |
| # from dotenv import load_dotenv | |
| # load_dotenv() | |
# API key for the Aya Vision endpoint, read from the environment (e.g. HF Spaces
# secrets). NOTE(review): os.getenv returns None when unset — the client below is
# still constructed and requests would only fail at call time; confirm the secret
# is configured in deployment.
MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")

# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)

# Cohere v2 client shared by all Aya Vision chat calls in this module.
aya_vision_client = cohere.ClientV2(
    api_key=MULTIMODAL_API_KEY,
    client_name="c4ai-aya-vision-hf-space"
)
def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    """Send a (non-streaming) chat request to the Aya Vision model.

    Args:
        chat_history: list of Cohere v2 message dicts to send.
        model: Cohere model name; defaults to the configured vision model.

    Returns:
        The text of the first content item in the model's reply.
    """
    reply = aya_vision_client.chat(model=model, messages=chat_history)
    return reply.message.content[0].text
def get_aya_vision_prompt_example(language):
    """Look up the canned prompt example for *language*.

    Args:
        language: key into AYA_VISION_PROMPT_EXAMPLES.

    Returns:
        (prompt_text, image) tuple taken from the example entry.

    Raises:
        KeyError: if *language* has no example entry.
    """
    # Fix: use the module logger instead of bare print() calls so the debug
    # output respects the configured logging level/handlers.
    prompt, image = AYA_VISION_PROMPT_EXAMPLES[language]
    logger.debug("example prompt: %s", prompt)
    logger.debug("example image: %s", image)
    return prompt, image
def get_base64_from_local_file(file_path):
    """Read a local file and return its contents base64-encoded as a str.

    Args:
        file_path: path to the file (typically an image) to encode.

    Returns:
        The base64-encoded file contents, or None if the file cannot be read.
    """
    try:
        with open(file_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    except OSError as e:
        # Fix: log read failures at ERROR (DEBUG is filtered out by default),
        # catch only OSError instead of every exception, and drop the stray
        # print() progress messages.
        logger.error("Error converting local image to base64 string: %s", e)
        return None
def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    """Stream an Aya Vision chat response for a text prompt plus one local image.

    Args:
        incoming_message: user prompt text; empty/None is replaced with "."
            because the Cohere API rejects empty messages.
        image_filepath: path to a local .jpg/.jpeg/.png/.webp/.gif image.
        max_size_mb: maximum allowed image size in megabytes (default 5).

    Yields:
        The accumulated response text after each streamed content delta.

    Raises:
        gr.Error: for an unsupported image extension or an oversized image.
    """
    logger.debug("incoming message: %s", incoming_message)
    logger.debug("image_filepath: %s", image_filepath)

    # Map the file extension to the MIME type embedded in the data URL.
    mime_by_ext = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    lowered = image_filepath.lower()
    image_type = next(
        (mime for ext, mime in mime_by_ext.items() if lowered.endswith(ext)),
        None,
    )
    # Fix: the original left image_type unbound (NameError) for any other extension.
    if image_type is None:
        raise gr.Error("Unsupported image format. Please upload a JPG, PNG, WEBP or GIF image.")

    base64_image = get_base64_from_local_file(image_filepath)
    image = f"data:{image_type};base64,{base64_image}"

    # Fix: validate the size BEFORE calling the API, and actually raise the
    # Gradio error — the original constructed gr.Error(...) without raising it,
    # so oversized images were streamed to the API anyway.
    max_size_bytes = max_size_mb * 1024 * 1024
    if get_base64_image_size(image) >= max_size_bytes:
        raise gr.Error("Please upload image with size under 5MB")

    # Prevent the Cohere API from rejecting an empty message.
    if not incoming_message:
        incoming_message = "."

    chat_history = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    ]

    res = aya_vision_client.chat_stream(
        messages=chat_history, model=VISION_COHERE_MODEL_NAME
    )
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output
def get_base64_image_size(base64_string):
    """Compute the decoded byte size of a base64 payload without decoding it.

    Accepts either a bare base64 string or a full data URL
    ("data:<mime>;base64,<payload>"); whitespace is ignored.

    Args:
        base64_string: base64 text, optionally with a data-URL prefix.

    Returns:
        Size in bytes of the data the string would decode to.
    """
    # Strip a data-URL prefix if one is present.
    _, comma, tail = base64_string.partition(',')
    payload = tail if comma else base64_string
    # Drop the whitespace characters that may appear in wrapped base64.
    payload = payload.translate(str.maketrans('', '', '\n\r '))
    # Each 4 base64 chars encode 3 bytes; '=' padding chars encode nothing.
    return (len(payload) * 3) // 4 - payload.count('=')