Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
| from fastapi import FastAPI, HTTPException, Query, Request | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from langchain.chains import LLMChain | |
| from langchain.prompts import PromptTemplate | |
| from TextGen.suno import custom_generate_audio, get_audio_information | |
| from TextGen.coqui import predict | |
| from langchain_google_genai import ( | |
| ChatGoogleGenerativeAI, | |
| HarmBlockThreshold, | |
| HarmCategory, | |
| ) | |
| from TextGen import app | |
| from gradio_client import Client, handle_file | |
| from typing import List | |
# Tool schema: one of the classes handed to `llm.bind_tools(...)` in
# `generate_text`. The field name `Desicion` (sic) is part of the schema the
# model sees, so it is kept exactly as spelled.
class PlayLastMusic(BaseModel):
    '''plays the lastest created music '''

    # Expected to carry the literal answer "Yes" or "No".
    Desicion: str = Field(..., description="Yes or No")
# Tool schema: one of the classes handed to `llm.bind_tools(...)` in
# `generate_text`. The field name `Desicion` (sic) is part of the schema the
# model sees, so it is kept exactly as spelled.
class CreateLyrics(BaseModel):
    '''create some Lyrics for a new music'''

    # NOTE: the string above must be a plain literal. An f-string in the
    # docstring position is an ordinary expression statement, so the class
    # would end up with no docstring at all.

    # Expected to carry the literal answer "Yes" or "No".
    Desicion: str = Field(..., description="Yes or No")
# Tool schema: one of the classes handed to `llm.bind_tools(...)` in
# `generate_text`.
class CreateNewMusic(BaseModel):
    '''create a new music with the Lyrics previously computed'''

    # NOTE: the string above must be a plain literal. An f-string in the
    # docstring position is an ordinary expression statement, so the class
    # would end up with no docstring at all.

    # Free-text tags describing the music to create (per the field description).
    Name: str = Field(..., description="tags to describe the new music")
# Payload consumed by `inference`: which NPC is speaking and the chat so far.
# Both fields are optional and default to None.
class Message(BaseModel):
    # NPC name; forwarded to `generate_text` as `npc`.
    npc: str | None = None

    # Chat turns in order; forwarded to `generate_text` as `messages`.
    messages: List[str] | None = None
# Payload consumed by `generate_wav` / `generate_voice`.
class VoiceMessage(BaseModel):
    # NPC name; used together with `genre` to choose the reference voice file.
    npc: str | None = None

    # Text passed to the TTS call as `prompt`.
    input: str | None = None

    # Language code passed to the TTS call as `language`.
    language: str | None = "en"

    # Voice category used as a fallback when `npc` has no dedicated voice.
    genre: str | None = "Male"
# Both settings are read with `os.environ[...]` at import time, so a missing
# variable raises KeyError while the module is being loaded.
song_base_api = os.environ["VERCEL_API"]
my_hf_token = os.environ["HF_TOKEN"]

# Shared Gradio client for the "Jofthomas/xtts" app, authenticated with the
# Hugging Face token read above.
tts_client = Client("Jofthomas/xtts", hf_token=my_hf_token)
# Reference voice sample for each named NPC (path relative to the working dir).
main_npcs = {
    "Blacksmith": "./voices/Blacksmith.mp3",
    "Herbalist": "./voices/female.mp3",
    "Bard": "./voices/Bard_voice.mp3",
}

# System prompt for each named NPC; keyed by the same names as `main_npcs`.
main_npc_system_prompts = {
    "Blacksmith": "You are a blacksmith in a video game",
    "Herbalist": "You are an herbalist in a video game",
    "Bard": "You are a bard in a video game",
}
# Response wrapper returned by `generate_text`.
class Generate(BaseModel):
    # Content of the model's reply.
    text: str
def generate_text(messages: List[str], npc: str):
    """Generate one NPC reply from an alternating chat history.

    Args:
        messages: Chat turns in order. Even indexes are sent with the
            "user" role and odd indexes with the "assistant" role. ``None``
            is treated as an empty history.
        npc: NPC name used to pick the system prompt. Names without an
            entry in ``main_npc_system_prompts`` (including ``None``) get a
            generic prompt.

    Returns:
        A ``Generate`` whose ``text`` is the ``content`` of the model response.
    """
    print(npc)
    # Look the prompt up in the prompt table itself instead of testing
    # membership in the voice table and then indexing a different dict.
    system_prompt = main_npc_system_prompts.get(
        npc, "you're a character in a video game. Play along."
    )
    print(system_prompt)

    # The system prompt is sent as the first "user" turn, followed by the
    # history with alternating roles.
    new_messages = [{"role": "user", "content": system_prompt}]
    # `Message.messages` defaults to None, so guard against a missing history.
    for index, message in enumerate(messages or []):
        role = "user" if index % 2 == 0 else "assistant"
        new_messages.append({"role": role, "content": message})
    print(new_messages)

    # Initialize the LLM
    llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-pro-latest",
        max_output_tokens=100,
        temperature=1,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        },
    )

    # The NPC tables in this module use the capitalised key "Bard", so an
    # exact comparison with "bard" never matched the NPC that gets the bard
    # prompt. Compare case-insensitively so both spellings bind the tools.
    if npc is not None and npc.lower() == "bard":
        llm = llm.bind_tools([PlayLastMusic, CreateNewMusic, CreateLyrics])

    llm_response = llm.invoke(new_messages)
    print(llm_response)
    return Generate(text=llm_response.content)
# Install CORS handling on the app: every origin, method and header is
# allowed, and credentials are enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
def api_home():
    """Return the static landing payload for the backend."""
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file; confirm it is actually registered on `app`.
    landing = {'detail': 'Everchanging Quest backend, nothing to see here'}
    return landing
def inference(message: Message):
    """Run text generation for the chat history and NPC carried by ``message``."""
    history = message.messages
    npc_name = message.npc
    return generate_text(messages=history, npc=npc_name)
#Dummy function for now
def determine_vocie_from_npc(npc, genre):
    """Pick the reference voice file for an NPC.

    Args:
        npc: NPC name. Names present in ``main_npcs`` use their own voice file.
        genre: Fallback voice category for other NPCs: "Male" or "Female".

    Returns:
        The voice file path as a string. Any other ``genre`` value selects
        the narrator voice.
    """
    if npc in main_npcs:
        return main_npcs[npc]
    if genre == "Male":
        # The original had this path as a bare expression with no `return`,
        # so "Male" fell through to the narrator voice.
        return "./voices/default_male.mp3"
    if genre == "Female":
        return "./voices/default_female.mp3"
    return "./voices/narator_out.wav"
async def generate_wav(message: VoiceMessage):
    """Stream speech for ``message`` synthesized through ``tts_client``.

    Args:
        message: Carries the text (``input``), ``language``, and the
            ``npc``/``genre`` pair used to choose the reference voice.

    Returns:
        A ``StreamingResponse`` with media type "audio/wav" whose body is
        the bytes of each audio chunk in order.

    Raises:
        HTTPException: Status 500 if preparing the response fails. Errors
            raised while the stream is being consumed happen after this
            function has returned and are not converted here.
    """
    # Imported locally: the body uses asyncio but the module's import block
    # does not provide it.
    import asyncio

    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        audio_file_pth = handle_file(voice)

        # Generator function to yield audio chunks
        async def audio_stream():
            # `predict` is a synchronous call; run it in a worker thread so
            # the event loop is not held up while waiting for the result.
            result = await asyncio.to_thread(
                tts_client.predict,
                prompt=message.input,
                language=message.language,
                audio_file_pth=audio_file_pth,
                mic_file_path=None,
                use_mic=False,
                voice_cleanup=False,
                no_lang_auto_detect=False,
                agree=True,
                api_name="/predict",
            )
            # Each item is unpacked as a (sampling_rate, chunk) pair; only
            # the chunk is sent to the client.
            for _sampling_rate, audio_chunk in result:
                yield audio_chunk.tobytes()
                await asyncio.sleep(0)  # Yield control to the event loop

        # Return the generated audio as a streaming response
        return StreamingResponse(audio_stream(), media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_voice(message: VoiceMessage):
    """Stream speech for ``message`` synthesized through the local ``predict``.

    Args:
        message: Carries the text (``input``), ``language``, and the
            ``npc``/``genre`` pair used to choose the reference voice.

    Returns:
        A ``StreamingResponse`` with media type "audio/wav" whose body is
        the bytes of each audio chunk in order.

    Raises:
        HTTPException: Status 500 if preparing the response fails. Errors
            raised while the stream is being consumed happen after this
            function has returned and are not converted here.
    """
    # Imported locally: the body uses asyncio but the module's import block
    # does not provide it.
    import asyncio

    try:
        voice = determine_vocie_from_npc(message.npc, message.genre)
        audio_file_pth = handle_file(voice)

        # Generator function to yield audio chunks
        async def audio_stream():
            result = predict(
                prompt=message.input,
                language=message.language,
                audio_file_pth=audio_file_pth,
                mic_file_path=None,
                use_mic=False,
                voice_cleanup=False,
                no_lang_auto_detect=False,
                agree=True,
            )
            # Each item is unpacked as a (sampling_rate, chunk) pair; only
            # the chunk is sent to the client.
            for _sampling_rate, audio_chunk in result:
                yield audio_chunk.tobytes()
                await asyncio.sleep(0)  # Yield control to the event loop

        # Return the generated audio as a streaming response
        return StreamingResponse(audio_stream(), media_type="audio/wav")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_song(text: str):
    """Request a song for ``text`` and poll until its audio starts streaming.

    Submits the prompt with ``custom_generate_audio``, then calls
    ``get_audio_information`` up to 60 times, five seconds apart, stopping
    as soon as the first clip reports status 'streaming'. The clip ids and
    audio URLs are printed; nothing is returned.

    Args:
        text: Prompt sent as the "prompt" field of the generation request.

    Failures are reported on stdout rather than raised (best effort).
    """
    # Imported locally: the module's import block does not provide asyncio.
    import asyncio

    try:
        data = custom_generate_audio({
            "prompt": f"{text}",
            "make_instrumental": False,
            "wait_audio": False,
        })
        # The request is expected to yield two clips; both ids are polled
        # together as a comma-separated string.
        ids = f"{data[0]['id']},{data[1]['id']}"
        print(f"ids: {ids}")
        for _ in range(60):
            data = get_audio_information(ids)
            if data[0]["status"] == 'streaming':
                print(f"{data[0]['id']} ==> {data[0]['audio_url']}")
                print(f"{data[1]['id']} ==> {data[1]['audio_url']}")
                break
            # Wait 5s between polls without blocking the event loop
            # (time.sleep would stall every other request).
            await asyncio.sleep(5)
    except Exception as e:
        # Catch Exception rather than using a bare `except`, so task
        # cancellation and interpreter exit still propagate, and include the
        # cause so the failure can be diagnosed.
        print(f"Error: {e!r}")