from __future__ import annotations

import os
import io
import re
import time
import uuid
import torch
import cohere
import random
import secrets
import requests
import fasttext
import replicate
import numpy as np
import gradio as gr
from PIL import Image
from groq import Groq
from TTS.api import TTS
from elevenlabs import save
from gradio.themes.base import Base
from elevenlabs.client import ElevenLabs
from huggingface_hub import hf_hub_download
from gradio.themes.utils import colors, fonts, sizes
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS, AYA_VISION_PROMPT_EXAMPLES
from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE
from aya_vision_utils import get_aya_vision_response, get_aya_vision_prompt_example
# from dotenv import load_dotenv
# load_dotenv()

HF_API_TOKEN = os.getenv("HF_API_KEY")
ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
IMG_GEN_COHERE_API_KEY = os.getenv("IMG_GEN_COHERE_API_KEY")
AUDIO_COHERE_API_KEY = os.getenv("AUDIO_COHERE_API_KEY")
CHAT_COHERE_API_KEY = os.getenv("CHAT_COHERE_API_KEY")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
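# A separate Cohere client is created per feature (image prompts, chat, audio),
# each with its own API key and client_name, presumably so usage can be tracked
# per tab.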
# Initialize Cohere clients
img_prompt_client = cohere.Client(
    api_key=IMG_GEN_COHERE_API_KEY,
    client_name="c4ai-aya-expanse-img",
)
chat_client = cohere.Client(
    api_key=CHAT_COHERE_API_KEY,
    client_name="c4ai-aya-expanse-chat",
)
audio_response_client = cohere.Client(
    api_key=AUDIO_COHERE_API_KEY,
    client_name="c4ai-aya-expanse-audio",
)

# Initialize the Groq client
groq_client = Groq(api_key=GROQ_API_KEY)

# Initialize the ElevenLabs client
eleven_labs_client = ElevenLabs(
    api_key=ELEVEN_LABS_KEY,
)
# Language identification
lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
LID_model = fasttext.load_model(lid_model_path)
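# The fastText LID model returns FLORES-200 style codes such as "eng_Latn" or
# "jpn_Jpan"; these are used below to decide whether an image prompt needs
# translation and which TTS voice to use.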
def predict_language(text):
    text = re.sub("\n", " ", text)
    label, logit = LID_model.predict(text)
    label = label[0][len("__label__"):]
    print("predicted language:", label)
    return label
# Image generation util functions
def choose_img_prompt_examples(language):
    example_choice = random.choice(IMG_GEN_PROMPT_EXAMPLES[language])
    return example_choice

def get_hf_inference_api_response(payload, model_id):
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    MODEL_API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(MODEL_API_URL, headers=headers, json=payload)
    return response.content
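# Image generation backends: Replicate (Flux Schnell) is used when USE_REPLICATE
# is set; otherwise the HF Inference API is tried first, with Replicate as the
# fallback if that call fails.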
def replicate_api_inference(input_prompt):
    input_params = {
        "prompt": input_prompt,
        "go_fast": True,
        "megapixels": "1",
        "num_outputs": 1,
        "aspect_ratio": "1:1",
        "output_format": "jpg",
        "output_quality": 80,
        "enable_safety_checker": True,
        "safety_tolerance": 1,
        "num_inference_steps": 4,
    }
    image = replicate.run("black-forest-labs/flux-schnell", input=input_params)
    image = Image.open(image[0])
    return image
def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
    if input_prompt:
        if USE_REPLICATE:
            print("using replicate for image generation")
            image = replicate_api_inference(input_prompt)
        else:
            try:
                print("using HF inference API for image generation")
                image_bytes = get_hf_inference_api_response({"inputs": input_prompt}, model_id)
                image = np.array(Image.open(io.BytesIO(image_bytes)))
            except Exception as e:
                print("HF API error:", e)
                # fall back to Replicate if the HF Inference API call fails
                image = replicate_api_inference(input_prompt)
        return image
    else:
        return None
def generate_img_prompt(input_prompt):
    if input_prompt:
        # clean prompt before doing language detection
        cleaned_prompt = clean_text(input_prompt, remove_bullets=True, remove_newline=True)
        text_lang_code = predict_language(cleaned_prompt)
        gr.Info("Generating Image", duration=2)

        if text_lang_code != "eng_Latn":
            text = f"""
Translate the given input prompt to English.
Input Prompt: {input_prompt}
Then based on the English translation of the prompt, generate a detailed image description which can be used to generate an image using a text-to-image model.
Do not use more than 3-4 lines for the image description. Respond with only the image description.
"""
        else:
            text = f"""Generate a detailed image description which can be used to generate an image using a text-to-image model based on the given input prompt:
Input Prompt: {input_prompt}
Do not use more than 3-4 lines for the description.
"""
        response = img_prompt_client.chat(message=text, preamble=IMG_DESCRIPTION_PREAMBLE, model=AYA_MODEL_NAME)
        output = response.text
        return output
    else:
        return None
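# Note: in the UI wiring below, the description returned here is written to the
# generated_img_desc textbox, whose .change event then calls generate_image, so
# "Visualize with Aya" is a two-step Aya Expanse -> Flux pipeline.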
# Chat with Aya util functions
def choose_chat_examples(language):
    example_choice = random.choice(TEXT_CHAT_EXAMPLES[language])
    return example_choice

def trigger_example(example):
    # generate_aya_chat_response is a generator and also needs cid/token, so
    # exhaust it with defaults and return the final chat state.
    chat, history = [], []
    for chat, history, _cid in generate_aya_chat_response(example, cid="", token=None):
        pass
    return chat, history
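# History is kept as a flat list of alternating user/assistant messages; it is
# paired into (user, assistant) tuples for gr.Chatbot and yielded on every
# streamed chunk so the UI updates incrementally.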
def generate_aya_chat_response(user_message, cid, token, history=None):
    if not token:
        print("no token")
        # raise gr.Error("Error loading.")
    if history is None:
        history = []
    if cid == "" or cid is None:
        cid = str(uuid.uuid4())
    print(f"cid: {cid} prompt:{user_message}")

    history.append(user_message)
    stream = chat_client.chat_stream(message=user_message, preamble=CHAT_PREAMBLE, conversation_id=cid, model=AYA_MODEL_NAME, connectors=[], temperature=0.3)

    output = ""
    for idx, response in enumerate(stream):
        if response.event_type == "text-generation":
            output += response.text
        if idx == 0:
            history.append(" " + output)
        else:
            history[-1] = output
        chat = [
            (history[i].strip(), history[i + 1].strip())
            for i in range(0, len(history) - 1, 2)
        ]
        yield chat, history, cid

    return chat, history, cid
def clear_chat():
    return [], [], str(uuid.uuid4())
# Audio Pipeline util functions
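# The "Speak with Aya" tab chains three steps (see the .then() wiring below):
# speech-to-text -> Aya Expanse text response -> text-to-speech.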
def transcribe_and_stream(inputs, model_name="groq_whisper", show_info="show_info", language="english"):
    if inputs:
        if show_info == "show_info":
            gr.Info("Processing Audio", duration=1)
        if model_name != "groq_whisper":
            print("DEVICE:", DEVICE)
            pipe = pipeline(
                task="automatic-speech-recognition",
                model=model_name,
                chunk_length_s=30,
                device=DEVICE,
            )
            text = pipe(inputs, batch_size=BATCH_SIZE, return_timestamps=True)["text"]
        else:
            text = groq_whisper_tts(inputs)
        # stream text output
        for i in range(len(text)):
            time.sleep(0.01)
            yield text[: i + 10]
    else:
        return ""
def aya_speech_text_response(text):
    if text:
        stream = audio_response_client.chat_stream(message=text, preamble=AUDIO_RESPONSE_PREAMBLE, model=AYA_MODEL_NAME)
        output = ""
        for event in stream:
            if event:
                if event.event_type == "text-generation":
                    output += event.text
                    cleaned_output = clean_text(output)
                    yield cleaned_output
    else:
        return ""
def clean_text(text, remove_bullets=False, remove_newline=False):
    # Remove bold formatting
    cleaned_text = re.sub(r"\*\*", "", text)
    if remove_bullets:
        cleaned_text = re.sub(r"^- ", "", cleaned_text, flags=re.MULTILINE)
    if remove_newline:
        cleaned_text = re.sub(r"\n", " ", cleaned_text)
    return cleaned_text
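# TTS routing: neets.ai voices are used by default, XTTS-v2 handles Japanese
# (which neets.ai does not cover), and ElevenLabs is used instead when
# USE_ELVENLABS is enabled in constants.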
def convert_text_to_speech(text, language="english"):
    # do language detection to determine the voice of the speech response
    if text:
        # clean text before doing language detection
        cleaned_text = clean_text(text, remove_bullets=True, remove_newline=True)
        text_lang_code = predict_language(cleaned_text)

        if not USE_ELVENLABS:
            if text_lang_code != "jpn_Jpan":
                audio_path = neetsai_tts(text, text_lang_code)
            else:
                print("DEVICE:", DEVICE)
                # if the language is Japanese, use XTTS for TTS since neets.ai doesn't support a Japanese voice
                tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
                speaker_wav = "samples/ja-sample.wav"
                lang_code = "ja"
                audio_path = "./output.wav"
                tts.tts_to_file(text=text, speaker_wav=speaker_wav, language=lang_code, file_path=audio_path)
        else:
            # use ElevenLabs for TTS
            audio_path = elevenlabs_generate_audio(text)
        return audio_path
    else:
        return None
def elevenlabs_generate_audio(text):
    audio = eleven_labs_client.generate(
        text=text,
        voice="River",
        model="eleven_turbo_v2_5",  # "eleven_multilingual_v2"
    )
    # save audio
    audio_path = "./audio.mp3"
    save(audio, audio_path)
    return audio_path
def neetsai_tts(input_text, text_lang_code):
    if text_lang_code in LID_LANGUAGES.keys():
        language = LID_LANGUAGES[text_lang_code]
    else:
        # use the English voice as default for languages outside the 23 languages of Aya Expanse
        language = "english"
    neets_lang_id = NEETS_AI_LANGID_MAP[language]
    neets_vits_voice_id = f"vits-{neets_lang_id}"

    response = requests.request(
        method="POST",
        url="https://api.neets.ai/v1/tts",
        headers={
            "Content-Type": "application/json",
            "X-API-Key": NEETS_AI_API_KEY,
        },
        json={
            "text": input_text,
            "voice_id": neets_vits_voice_id,
            "params": {
                "model": "vits"
            },
        },
    )

    # save audio file
    audio_path = "neets_demo.mp3"
    with open(audio_path, "wb") as f:
        f.write(response.content)
    return audio_path
def groq_whisper_tts(filename):
    # Note: despite the name, this performs speech-to-text (transcription)
    # using Whisper large-v3-turbo hosted on Groq.
    with open(filename, "rb") as file:
        transcriptions = groq_client.audio.transcriptions.create(
            file=(filename, file.read()),
            model="whisper-large-v3-turbo",
            response_format="json",
            temperature=0.0,
        )
    print("transcribed text:", transcriptions.text)
    print("********************************")
    return transcriptions.text
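# Gradio UI: a custom theme plus four tabs: "Aya Vision" (image understanding),
# "Chat with Aya", "Speak with Aya" (voice in/out) and "Visualize with Aya"
# (image generation).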
# setup gradio app theme
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
).set(
    # Primary Button Color
    button_primary_background_fill="#2F70E3",  # "#114A56",
    button_primary_background_fill_hover="#2F70E3",  # "#114A56",
    # Block Labels
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)
demo = gr.Blocks(theme=theme, analytics_enabled=False)

with demo:
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            gr.Image("1.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False)
        with gr.Column(scale=30):
            gr.Markdown("""The C4AI Aya model family includes state-of-the-art models such as Aya Vision and Aya Expanse, with highly advanced capabilities to connect the world across languages.
<br/>
You can use this space to chat, speak, visualize and see with Aya models in 23 languages.
<br/>
<br/>
**Model**: [aya-vision-32B](https://huggingface.co/CohereForAI/aya-vision-32b), [aya-expanse-32B](https://huggingface.co/CohereForAI/aya-expanse-32b)
<br/>
**Developed by**: [Cohere for AI](https://cohere.com/research) and [Cohere](https://cohere.com/)
<br/>
**License**: [CC-BY-NC](https://cohere.com/c4ai-cc-by-nc-license); use also requires adherence to [C4AI's Acceptable Use Policy](https://docs.cohere.com/docs/c4ai-acceptable-use-policy)
"""
            )
    # Aya Vision
    with gr.TabItem("Aya Vision") as see_with_aya:
        with gr.Row():
            with gr.Column():
                aya_vision_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Input Prompt", lines=3)
                aya_vision_input_img = gr.Image(label="Input Image", interactive=True, type="filepath")
                submit_aya_vision = gr.Button(value="Submit", variant="primary")
                clear_button_aya_vision = gr.ClearButton()
            with gr.Column():
                aya_vision_response = gr.Textbox(lines=3, label="Aya Vision's Response", show_copy_button=True, container=True, interactive=False)
                lang_textbox = gr.Textbox(visible=False)
        with gr.Row():
            gr.Examples(
                examples=[[lang] for lang in AYA_VISION_PROMPT_EXAMPLES.keys()],
                inputs=lang_textbox,
                outputs=[aya_vision_prompt, aya_vision_input_img],
                fn=get_aya_vision_prompt_example,
                label="Load example prompt for:",
                examples_per_page=25,
                run_on_click=True,
            )
        # increase spacing between Examples and Accordion components
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            with gr.Accordion("See Details", open=False):
                gr.Markdown("This space uses [Aya Vision](https://huggingface.co/CohereForAI/aya-vision-32b) for understanding images.")
    with gr.TabItem("Chat with Aya") as chat_with_aya:
        cid = gr.State("")
        token = gr.State(value=None)

        with gr.Column():
            with gr.Row():
                chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, height=300)
            with gr.Row():
                user_message = gr.Textbox(lines=1, placeholder="Ask anything in our 23 languages ...", label="Input", show_label=False)
                msg_temp = gr.Textbox(visible=False)
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary")
                clear_button = gr.Button("Clear")

            history = gr.State([])

            user_message.submit(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
            submit_button.click(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
            clear_button.click(fn=clear_chat, inputs=None, outputs=[chatbot, history, cid], concurrency_limit=32)

            # clear the input textbox after submitting or clearing
            user_message.submit(lambda: gr.update(value=""), None, [user_message], queue=False)
            submit_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)
            clear_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)

            with gr.Row():
                gr.Examples(
                    examples=[[lang] for lang in TEXT_CHAT_EXAMPLES.keys()],
                    inputs=msg_temp,
                    outputs=user_message,
                    fn=choose_chat_examples,
                    label="Load example prompt for:",
                    examples_per_page=25,
                    run_on_click=True,
                )
    # End-to-end pipeline for Speak with Aya
    with gr.TabItem("Speak with Aya") as speak_with_aya:
        with gr.Row():
            with gr.Column():
                e2e_audio_file = gr.Audio(sources="microphone", type="filepath", min_length=None)
                e2_audio_submit_button = gr.Button(value="Get Aya's Response", variant="primary")
                clear_button_microphone = gr.ClearButton()
                gr.Examples(
                    examples=AUDIO_EXAMPLES,
                    inputs=e2e_audio_file,
                    cache_examples=False,
                    examples_per_page=25,
                    label="Load example audio for:",
                    example_labels=AUDIO_EXAMPLES_LABELS,
                )
            with gr.Column():
                e2e_audio_file_trans = gr.Textbox(lines=3, label="Your Input", autoscroll=False, show_copy_button=True, interactive=False)
                e2e_audio_file_aya_response = gr.Textbox(lines=3, label="Aya's Response", show_copy_button=True, container=True, interactive=False)
                e2e_aya_audio_response = gr.Audio(type="filepath", label="Aya's Audio Response")
        with gr.Accordion("See Details", open=False):
            gr.Markdown("To enable voice interaction with Aya Expanse, this space uses [Whisper large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) served via [Groq](https://groq.com/) for STT and [neets.ai](http://neets.ai/) for TTS.")
    # Generate images
    with gr.TabItem("Visualize with Aya") as visualize_with_aya:
        with gr.Row():
            with gr.Column():
                input_img_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Ask anything about an image", lines=3)
                submit_button_img = gr.Button(value="Submit", variant="primary")
                clear_button_img = gr.ClearButton()
            with gr.Column():
                generated_img = gr.Image(label="Generated Image", interactive=False)
                input_prompt_lang = gr.Textbox(visible=False)
        with gr.Row():
            gr.Examples(
                examples=[[lang] for lang in IMG_GEN_PROMPT_EXAMPLES.keys()],
                inputs=input_prompt_lang,
                outputs=input_img_prompt,
                fn=choose_img_prompt_examples,
                label="Load example prompt for:",
                examples_per_page=25,
                run_on_click=True,
            )
        generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
        # increase spacing between Examples and Accordion components
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            with gr.Accordion("See Details", open=False):
                gr.Markdown("This space uses Aya Expanse for translating multilingual prompts and generating detailed image descriptions, and [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for image generation.")
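    # Event wiring: connect the buttons and textboxes above to their handlers.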
    # Aya Vision
    clear_button_aya_vision.click(lambda: None, None, aya_vision_input_img)
    clear_button_aya_vision.click(lambda: None, None, aya_vision_prompt)
    clear_button_aya_vision.click(lambda: None, None, aya_vision_response)

    submit_aya_vision.click(
        get_aya_vision_response,
        inputs=[aya_vision_prompt, aya_vision_input_img],
        outputs=[aya_vision_response],
    )

    # Image Generation
    clear_button_img.click(lambda: None, None, input_img_prompt)
    clear_button_img.click(lambda: None, None, generated_img_desc)
    clear_button_img.click(lambda: None, None, generated_img)

    submit_button_img.click(
        generate_img_prompt,
        inputs=[input_img_prompt],
        outputs=[generated_img_desc],
    )
    generated_img_desc.change(
        generate_image,  # run_flux,
        inputs=[generated_img_desc],
        outputs=[generated_img],
        show_progress="full",
    )
    # Audio Pipeline
    clear_button_microphone.click(lambda: None, None, e2e_audio_file)
    clear_button_microphone.click(lambda: None, None, e2e_aya_audio_response)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_aya_response)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_trans)

    # e2e_audio_file.change(
    e2_audio_submit_button.click(
        transcribe_and_stream,
        inputs=[e2e_audio_file],
        outputs=[e2e_audio_file_trans],
        show_progress="full",
    ).then(
        aya_speech_text_response,
        inputs=[e2e_audio_file_trans],
        outputs=[e2e_audio_file_aya_response],
        show_progress="full",
    ).then(
        convert_text_to_speech,
        inputs=[e2e_audio_file_aya_response],
        outputs=[e2e_aya_audio_response],
        show_progress="full",
    )
    demo.load(lambda: secrets.token_hex(16), None, token)

demo.queue(api_open=False, max_size=20, default_concurrency_limit=4).launch(show_api=False, allowed_paths=["/home/user/app"])