rag / app.py
aiqtech's picture
Update app.py
b3f1dd2 verified
raw
history blame
10.4 kB
import os
import json
import asyncio
from typing import Optional, List, Dict
from contextlib import asynccontextmanager
import requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import gradio as gr
# Pydantic model definitions
class Message(BaseModel):
    """One chat turn exchanged with the model."""

    # Who produced the turn; this file sends "user" and "assistant".
    role: str
    # Text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Body of ``POST /chat``; also the input of ``FireworksClient.chat``.

    Numeric sampling options are range-checked by pydantic, so out-of-range
    values are rejected before anything is sent upstream.
    """

    # Conversation turns, forwarded to the model as given.
    messages: List[Message]
    # Fireworks model identifier.
    model: str = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507"
    # Maximum number of tokens to generate (1..8192).
    max_tokens: int = Field(4096, ge=1, le=8192)
    # Sampling temperature (0..2).
    temperature: float = Field(0.6, ge=0, le=2)
    # top_p sampling parameter (0..1).
    top_p: float = Field(1.0, ge=0, le=1)
    # top_k sampling parameter (1..100).
    top_k: int = Field(40, ge=1, le=100)
    # Presence / frequency penalties (-2..2).
    presence_penalty: float = Field(0, ge=-2, le=2)
    frequency_penalty: float = Field(0, ge=-2, le=2)
class ChatResponse(BaseModel):
    """Body returned by ``POST /chat``."""

    # Text of the first choice produced by the model.
    response: str
    # Model identifier, echoed from the request.
    model: str
    # ``usage.total_tokens`` from the upstream reply, if it was reported.
    tokens_used: Optional[int] = None
# Fireworks API client
class FireworksClient:
    """Minimal synchronous client for the Fireworks chat-completions endpoint.

    Args:
        api_key: Fireworks API key. When omitted or empty, the
            ``FIREWORKS_API_KEY`` environment variable is used.
        timeout: Timeout in seconds passed to ``requests.post``. Defaults to
            30; long generations (large ``max_tokens``) may need more.

    Raises:
        ValueError: If no API key is available.
    """

    def __init__(self, api_key: Optional[str] = None, timeout: float = 30):
        self.api_key = api_key or os.getenv("FIREWORKS_API_KEY")
        if not self.api_key:
            raise ValueError("API key is required. Set FIREWORKS_API_KEY environment variable.")
        self.timeout = timeout
        self.base_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _message_to_dict(msg) -> Dict:
        """Serialize a pydantic message on both pydantic v2 and v1."""
        # ``.dict()`` is deprecated in pydantic v2 in favour of ``model_dump()``.
        dump = getattr(msg, "model_dump", None)
        return dump() if dump is not None else msg.dict()

    def chat(self, request: ChatRequest) -> Dict:
        """Send a chat request to the Fireworks API.

        Args:
            request: Validated request (model, sampling options, messages).

        Returns:
            The decoded JSON body of the chat-completions response.

        Raises:
            HTTPException: Status 500 if the HTTP call fails (transport error,
                timeout, non-2xx status) or the body cannot be decoded as JSON.
        """
        payload = {
            "model": request.model,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "temperature": request.temperature,
            "messages": [self._message_to_dict(msg) for msg in request.messages],
        }
        try:
            response = requests.post(
                self.base_url,
                headers=self.headers,
                data=json.dumps(payload),
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        # ValueError covers JSON decode failures on older `requests` versions,
        # where they are not RequestException subclasses.
        except (requests.exceptions.RequestException, ValueError) as e:
            detail = f"API request failed: {str(e)}"
            # raise_for_status() discards the upstream error body, which is
            # usually the only place that says why the call was rejected
            # (e.g. an unknown model id); keep a bounded excerpt of it.
            upstream = getattr(e, "response", None)
            if upstream is not None and upstream.text:
                detail += f" | {upstream.text[:500]}"
            raise HTTPException(status_code=500, detail=detail) from e
# Gradio app creation
def create_gradio_app(client: FireworksClient):
    """Build the Gradio chat interface backed by ``client``.

    Args:
        client: Fireworks client used to answer every chat turn.

    Returns:
        The ``gr.Blocks`` app. Mounting or launching it is left to the caller.
    """
    # Texts that chat_with_llm stores in the history instead of a real model
    # reply. Turns carrying them are not replayed to the model as context.
    no_reply_text = "응닡을 받을 수 μ—†μŠ΅λ‹ˆλ‹€."
    error_prefix = "였λ₯˜ λ°œμƒ: "

    def history_to_messages(history):
        """Convert Gradio ``[user, assistant]`` pairs into ``Message`` objects."""
        messages = []
        for user_msg, assistant_msg in history:
            # A turn that failed locally never produced a model reply; replaying
            # it would feed our own error text back as an "assistant" message.
            if isinstance(assistant_msg, str) and (
                assistant_msg == no_reply_text or assistant_msg.startswith(error_prefix)
            ):
                continue
            if user_msg:
                messages.append(Message(role="user", content=user_msg))
            if assistant_msg:
                messages.append(Message(role="assistant", content=assistant_msg))
        return messages

    def chat_with_llm(
        message: str,
        history: List[List[str]],
        model: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        top_k: int
    ):
        """Gradio chat handler: answer ``message`` in the context of ``history``.

        Returns:
            ``("", history)``: the empty string clears the input textbox, and
            ``history`` gets a new ``[message, reply]`` pair appended. ``reply``
            is an error description when the request failed.
        """
        # Ignore empty or whitespace-only input instead of sending a blank turn.
        if not message or not message.strip():
            return "", history
        # The Chatbot value may be None (e.g. right after "clear"); work on a
        # copy so the caller's list is not mutated.
        history = list(history or [])

        try:
            messages = history_to_messages(history)
            # Add the current message.
            messages.append(Message(role="user", content=message))
            request = ChatRequest(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                top_k=top_k
            )
            response = client.chat(request)
            # Extract the text of the first choice, if any.
            choices = response.get("choices") or []
            if choices:
                assistant_response = choices[0]["message"]["content"]
            else:
                assistant_response = no_reply_text
        except Exception as e:
            # UI boundary: report the failure inside the chat instead of raising.
            assistant_response = f"{error_prefix}{str(e)}"

        history.append([message, assistant_response])
        return "", history

    # Build the Gradio interface.
    with gr.Blocks(title="LLM Chat Interface") as demo:
        gr.Markdown("# πŸš€ Fireworks LLM Chat Interface")
        gr.Markdown("Qwen3-235B λͺ¨λΈμ„ μ‚¬μš©ν•œ μ±„νŒ… μΈν„°νŽ˜μ΄μŠ€μž…λ‹ˆλ‹€.")
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    height=500,
                    label="μ±„νŒ… μ°½"
                )
                msg = gr.Textbox(
                    label="λ©”μ‹œμ§€ μž…λ ₯",
                    placeholder="λ©”μ‹œμ§€λ₯Ό μž…λ ₯ν•˜μ„Έμš”...",
                    lines=2
                )
                with gr.Row():
                    submit = gr.Button("전솑", variant="primary")
                    clear = gr.Button("λŒ€ν™” μ΄ˆκΈ°ν™”")
            with gr.Column(scale=1):
                gr.Markdown("### βš™οΈ μ„€μ •")
                model = gr.Textbox(
                    label="λͺ¨λΈ",
                    value="accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=0.6,
                    step=0.1,
                    label="Temperature"
                )
                max_tokens = gr.Slider(
                    minimum=100,
                    maximum=8192,
                    value=4096,
                    step=100,
                    label="Max Tokens"
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=1.0,
                    step=0.1,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Top K"
                )

        # Event handlers: the button and Enter in the textbox do the same thing.
        chat_inputs = [msg, chatbot, model, temperature, max_tokens, top_p, top_k]
        chat_outputs = [msg, chatbot]
        submit.click(chat_with_llm, inputs=chat_inputs, outputs=chat_outputs)
        msg.submit(chat_with_llm, inputs=chat_inputs, outputs=chat_outputs)
        clear.click(lambda: None, None, chatbot, queue=False)
    return demo
# FastAPI app setup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook for the FastAPI app.

    Prints a banner when the app starts and another when it shuts down.
    Nothing else is set up or torn down.
    """
    startup_banner = "πŸš€ Starting FastAPI + Gradio server..."
    shutdown_banner = "πŸ‘‹ Shutting down server..."
    print(startup_banner)
    # Execution resumes after this point when the application shuts down.
    yield
    print(shutdown_banner)
app = FastAPI(
    title="LLM API with Gradio Interface",
    description="Fireworks LLM API with Gradio testing interface",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: every origin, method and header is accepted.
# NOTE(review): FastAPI's CORS docs say wildcard origins/methods/headers should
# not be combined with allow_credentials=True. Nothing in this file relies on
# cookies, so confirm whether credentials are really needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the Fireworks client. A missing API key is not fatal: the server
# still starts, but `fireworks_client` is left as None.
try:
    fireworks_client = FireworksClient()
except ValueError as err:
    print(f"⚠️ Warning: {err}")
    print("API endpoints will not work without a valid API key.")
    fireworks_client = None
# API endpoints
@app.get("/")
async def root():
    """Root endpoint: report that the server is up and list the main routes."""
    routes = {
        "api": "/chat",
        "gradio": "/gradio",
        "docs": "/docs",
    }
    return {"message": "LLM API Server is running", "endpoints": routes}
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat API endpoint: forward ``request`` to Fireworks and return the reply.

    Returns:
        ChatResponse containing the first choice's text, the requested model id
        and, when the upstream reports it, ``usage.total_tokens``.

    Raises:
        HTTPException: 500 when no API key is configured, when the upstream
            call fails, or when the upstream response contains no choices.
    """
    client = fireworks_client
    if not client:
        raise HTTPException(status_code=500, detail="API key not configured")
    try:
        # FireworksClient.chat blocks on `requests`. Run it in the default
        # executor so a slow completion does not freeze the event loop shared
        # by every other request (including the mounted Gradio UI).
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, client.chat, request)

        # Parse the response; `choices` and `usage` may be missing or null.
        choices = response.get("choices") or []
        if not choices:
            raise HTTPException(status_code=500, detail="Invalid response from API")
        content = choices[0]["message"]["content"]
        tokens = (response.get("usage") or {}).get("total_tokens")
        return ChatResponse(
            response=content,
            model=request.model,
            tokens_used=tokens
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Always reports "healthy". ``api_configured`` tells whether a Fireworks
    client exists, i.e. whether an API key was found.
    """
    configured = fireworks_client is not None
    return {"status": "healthy", "api_configured": configured}
# Mount the Gradio UI under /gradio, but only when a Fireworks client exists;
# without an API key the UI would have nothing to call.
if fireworks_client:
    gradio_app = create_gradio_app(fireworks_client)
    app = gr.mount_gradio_app(
        app,
        gradio_app,
        path="/gradio",
    )
# Script entry point
if __name__ == "__main__":
    import getpass

    # Check the API key.
    if not os.getenv("FIREWORKS_API_KEY"):
        print("⚠️ κ²½κ³ : FIREWORKS_API_KEY ν™˜κ²½λ³€μˆ˜κ°€ μ„€μ •λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€.")
        print("μ„€μ • 방법:")
        print(" Linux/Mac: export FIREWORKS_API_KEY='your-api-key'")
        print(" Windows: set FIREWORKS_API_KEY=your-api-key")
        print("")
        # Optionally prompt for the key. getpass keeps the secret off the
        # screen. EOFError (no usable stdin, e.g. a container or hosted
        # runtime) is treated as "skip" instead of crashing before startup.
        try:
            api_key = getpass.getpass("API ν‚€λ₯Ό μž…λ ₯ν•˜μ„Έμš” (Enterλ₯Ό λˆ„λ₯΄λ©΄ κ±΄λ„ˆλœλ‹ˆλ‹€): ").strip()
        except EOFError:
            api_key = ""
        if api_key:
            os.environ["FIREWORKS_API_KEY"] = api_key
            # Rebind the module-level client so the /chat and /health handlers
            # see it, then mount the UI that was skipped at import time.
            fireworks_client = FireworksClient(api_key)
            gradio_app = create_gradio_app(fireworks_client)
            app = gr.mount_gradio_app(app, gradio_app, path="/gradio")

    # Start the server.
    print("\nπŸš€ μ„œλ²„λ₯Ό μ‹œμž‘ν•©λ‹ˆλ‹€...")
    print("πŸ“ API λ¬Έμ„œ: http://localhost:7860/docs")
    print("πŸ’¬ Gradio UI: http://localhost:7860/gradio")
    print("πŸ”§ API μ—”λ“œν¬μΈνŠΈ: http://localhost:7860/chat")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        reload=False
    )