import argparse
import json
import sys
import uvicorn

from io import BytesIO

from fastapi import FastAPI, Response
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from googletrans import Translator
from gtts import gTTS
from pydantic import BaseModel, Field


class ChatAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="Translation and TTS API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        # Load the list of supported languages from the bundled JSON file.
        with open("apis/lang_name.json", "r") as f:
            self.available_models = json.load(f)
        return self.available_models

    class ChatCompletionsPostItem(BaseModel):
        from_language: str = Field(
            default="auto",
            description="(str) Source language code; `auto` to detect",
        )
        to_language: str = Field(
            default="en",
            description="(str) Target language code, e.g. `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) Text to translate",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        translator = Translator()
        # Only accept target languages listed in lang_name.json; fall back to English.
        with open("apis/lang_name.json", "r") as f:
            available_langs = json.load(f)
        to_lang = "en"
        for lang_item in available_langs:
            if item.to_language == lang_item["code"]:
                to_lang = item.to_language
                break
        translated = translator.translate(item.input_text, dest=to_lang)
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)

    class DetectLanguagePostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for detection`",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        translator = Translator()
        detected = translator.detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)

    class TTSPostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for TTS`",
        )
        from_language: str = Field(
            default="en",
            description="(str) `TTS language`",
        )

    def text_to_speech(self, item: TTSPostItem):
        try:
            # Synthesize the text to MP3 in memory and return the raw audio bytes.
            audioobj = gTTS(text=item.input_text, lang=item.from_language, slow=False)
            mp3_fp = BytesIO()
            audioobj.write_to_fp(mp3_fp)
            mp3_fp.seek(0)  # rewind before reading, otherwise the buffer reads back empty
            return Response(content=mp3_fp.read(), media_type="audio/mpeg")
        except Exception:
            item_response = {
                "status": 400,
            }
            json_compatible_item_data = jsonable_encoder(item_response)
            return JSONResponse(content=json_compatible_item_data, status_code=400)

    def setup_routes(self):
        # Register every endpoint under both the bare prefix and /v1.
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available languages",
            )(self.get_available_models)
            self.app.post(
                prefix + "/translate",
                summary="Translate text",
            )(self.chat_completions)
            self.app.post(
                prefix + "/detect",
                summary="Detect language",
            )(self.detect_language)
            self.app.post(
                prefix + "/tts",
                summary="Text to speech",
            )(self.text_to_speech)


class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for the translation API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server port for the translation API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode (enables auto-reload)",
        )
        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app

# Allow cross-origin requests from any origin so the API can be called from
# browser-based front ends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    args = ArgParser().args
    # Dev mode turns on uvicorn auto-reload; production mode leaves it off.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

# python -m apis.chat_api     # [Docker] production mode
# python -m apis.chat_api -d  # [Dev] development mode
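
# Example client calls (a minimal sketch: it assumes the server is running
# locally on the default port 23333 and that the third-party `requests`
# package is installed; adjust host/port to your deployment):
#
#   import requests
#
#   # Translate "Hello" to French via the /translate endpoint.
#   r = requests.post(
#       "http://localhost:23333/translate",
#       json={"from_language": "auto", "to_language": "fr", "input_text": "Hello"},
#   )
#   print(r.json())  # {"from_language": ..., "to_language": "fr", "text": ..., "translate": ...}
#
#   # Fetch MP3 audio from the /tts endpoint and save it to disk.
#   audio = requests.post(
#       "http://localhost:23333/tts",
#       json={"input_text": "Hello", "from_language": "en"},
#   )
#   with open("hello.mp3", "wb") as f:
#       f.write(audio.content)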