|
|
import os |
|
|
import json |
|
|
import asyncio |
|
|
from typing import Optional, List, Dict |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
import requests |
|
|
import uvicorn |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in an LLM conversation."""

    # Speaker of the message, e.g. "user" or "assistant" (see create_gradio_app).
    role: str
    # Plain-text content of the message.
    content: str
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel):
    """Request payload for the /chat endpoint, forwarded to the Fireworks API."""

    # Full conversation history, oldest first.
    messages: List[Message]
    # Fireworks model identifier.
    model: str = "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507"
    # Maximum number of tokens to generate (1..8192).
    max_tokens: int = Field(default=4096, ge=1, le=8192)
    # Sampling temperature (0..2); higher is more random.
    temperature: float = Field(default=0.6, ge=0, le=2)
    # Nucleus-sampling probability mass (0..1).
    top_p: float = Field(default=1.0, ge=0, le=1)
    # Top-k sampling cutoff (1..100).
    top_k: int = Field(default=40, ge=1, le=100)
    # Penalties in the OpenAI-compatible range (-2..2).
    presence_penalty: float = Field(default=0, ge=-2, le=2)
    frequency_penalty: float = Field(default=0, ge=-2, le=2)
|
|
|
|
|
|
|
|
class ChatResponse(BaseModel):
    """Response payload returned by the /chat endpoint."""

    # Assistant reply text extracted from the first API choice.
    response: str
    # Model identifier echoed back from the request.
    model: str
    # Total tokens reported by the API's "usage" field, when present.
    tokens_used: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
class FireworksClient:
    """Thin synchronous client for the Fireworks chat-completions REST API."""

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the client.

        Args:
            api_key: Fireworks API key; falls back to the FIREWORKS_API_KEY
                environment variable when omitted.

        Raises:
            ValueError: if no API key is available from either source.
        """
        self.api_key = api_key or os.getenv("FIREWORKS_API_KEY")
        if not self.api_key:
            raise ValueError("API key is required. Set FIREWORKS_API_KEY environment variable.")

        self.base_url = "https://api.fireworks.ai/inference/v1/chat/completions"
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def chat(self, request: ChatRequest) -> Dict:
        """Send a chat-completion request to the Fireworks API.

        Args:
            request: the validated chat parameters and message history.

        Returns:
            The parsed JSON response body from the API.

        Raises:
            HTTPException: with status 500 when the HTTP request fails
                (network error, timeout, or non-2xx status).
        """
        # NOTE(review): request.messages uses pydantic v1's .dict(); if the
        # project moves to pydantic v2, switch to .model_dump().
        payload = {
            "model": request.model,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "temperature": request.temperature,
            "messages": [msg.dict() for msg in request.messages],
        }

        try:
            # `json=` lets requests serialize the payload itself — idiomatic
            # replacement for data=json.dumps(payload).
            response = requests.post(
                self.base_url,
                headers=self.headers,
                json=payload,
                timeout=30,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            # Chain the cause so the original network error stays visible.
            raise HTTPException(status_code=500, detail=f"API request failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
def create_gradio_app(client: FireworksClient):
    """Build and return the Gradio Blocks UI wired to the given Fireworks client."""

    def chat_with_llm(
        message: str,
        history: List[List[str]],
        model: str,
        temperature: float,
        max_tokens: int,
        top_p: float,
        top_k: int
    ):
        """Gradio chat callback: forward the conversation to the LLM and append the reply.

        Returns ("", updated_history) so the textbox clears after each turn.
        """
        # Ignore empty submissions; leave the history unchanged.
        if not message:
            return "", history

        # Rebuild the full message list from Gradio's [user, assistant] pairs.
        messages = []
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append(Message(role="user", content=user_msg))
            if assistant_msg:
                messages.append(Message(role="assistant", content=assistant_msg))

        messages.append(Message(role="user", content=message))

        try:
            request = ChatRequest(
                messages=messages,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
                top_k=top_k
            )

            response = client.chat(request)

            # Take the assistant text from the first choice when present.
            if "choices" in response and len(response["choices"]) > 0:
                assistant_response = response["choices"][0]["message"]["content"]
            else:
                assistant_response = "μλ΅μ λ°μ μ μμ΅λλ€."

            history.append([message, assistant_response])
            return "", history

        except Exception as e:
            # Surface any failure in the chat window instead of crashing the UI.
            error_msg = f"μ€λ₯ λ°μ: {str(e)}"
            history.append([message, error_msg])
            return "", history

    with gr.Blocks(title="LLM Chat Interface") as demo:
        gr.Markdown("# π Fireworks LLM Chat Interface")
        gr.Markdown("Qwen3-235B λͺ¨λΈμ μ¬μ©ν μ±νμΈν°νμ΄μ€μλλ€.")

        with gr.Row():
            # Left column: chat window, message box, and action buttons.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    height=500,
                    label="μ±νμ°½"
                )
                msg = gr.Textbox(
                    label="λ©μμ§ μλ ₯",
                    placeholder="λ©μμ§λ₯Ό μλ ₯νμΈμ...",
                    lines=2
                )
                with gr.Row():
                    submit = gr.Button("μ μ‘", variant="primary")
                    clear = gr.Button("λν μ΄κΈ°ν")

            # Right column: sampling-parameter controls.
            with gr.Column(scale=1):
                gr.Markdown("### βοΈ μ€μ ")
                model = gr.Textbox(
                    label="λͺ¨λΈ",
                    value="accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=0.6,
                    step=0.1,
                    label="Temperature"
                )
                max_tokens = gr.Slider(
                    minimum=100,
                    maximum=8192,
                    value=4096,
                    step=100,
                    label="Max Tokens"
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=1.0,
                    step=0.1,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=40,
                    step=1,
                    label="Top K"
                )

        # Both the button click and Enter in the textbox trigger the same handler.
        submit.click(
            chat_with_llm,
            inputs=[msg, chatbot, model, temperature, max_tokens, top_p, top_k],
            outputs=[msg, chatbot]
        )

        msg.submit(
            chat_with_llm,
            inputs=[msg, chatbot, model, temperature, max_tokens, top_p, top_k],
            outputs=[msg, chatbot]
        )

        # Reset the chat window; queue=False so the clear happens immediately.
        clear.click(lambda: None, None, chatbot, queue=False)

    return demo
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: runs once at app startup and once at shutdown."""
    # Startup side: runs before the server begins accepting requests.
    print("π Starting FastAPI + Gradio server...")
    yield
    # Shutdown side: runs after the server stops accepting requests.
    print("π Shutting down server...")
|
|
|
|
|
|
|
|
# FastAPI application; `lifespan` handles startup/shutdown logging.
app = FastAPI(
    title="LLM API with Gradio Interface",
    description="Fireworks LLM API with Gradio testing interface",
    version="1.0.0",
    lifespan=lifespan
)

# NOTE(review): CORS is wide open (all origins, methods, headers, with
# credentials) — acceptable for local testing, tighten before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Construct the shared Fireworks client at import time. When no API key is
# configured we fall back to None so the server can still start; the /chat
# endpoint then answers with a 500 until a key is provided.
try:
    fireworks_client = FireworksClient()
except ValueError as e:
    print(f"β οΈ Warning: {e}")
    print("API endpoints will not work without a valid API key.")
    fireworks_client = None
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Root endpoint: service banner plus a map of the available routes."""
    available = {
        "api": "/chat",
        "gradio": "/gradio",
        "docs": "/docs",
    }
    return {"message": "LLM API Server is running", "endpoints": available}
|
|
|
|
|
|
|
|
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Chat API endpoint: proxy the request to Fireworks and shape the reply."""
    # Guard: the module-level client is None when no API key was configured.
    if not fireworks_client:
        raise HTTPException(status_code=500, detail="API key not configured")

    try:
        api_response = fireworks_client.chat(request)

        # An empty or missing "choices" list means the API reply is unusable.
        choices = api_response.get("choices")
        if not choices:
            raise HTTPException(status_code=500, detail="Invalid response from API")

        content = choices[0]["message"]["content"]
        tokens = api_response.get("usage", {}).get("total_tokens")
        return ChatResponse(
            response=content,
            model=request.model,
            tokens_used=tokens,
        )
    except HTTPException:
        # Already shaped for FastAPI — propagate untouched.
        raise
    except Exception as e:
        # Anything unexpected becomes a generic 500 with the error text.
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health-check endpoint: reports liveness and whether an API key is configured."""
    api_ready = fireworks_client is not None
    return {"status": "healthy", "api_configured": api_ready}
|
|
|
|
|
|
|
|
|
|
|
# Mount the Gradio UI under /gradio, but only when a client could be
# constructed (i.e. an API key was available at import time).
if fireworks_client:
    gradio_app = create_gradio_app(fireworks_client)
    app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # (fix) removed `import sys` — it was imported here but never used.

    # Interactive fallback: when the environment lacks an API key, prompt for
    # one, then rebuild the client and remount the Gradio UI on the app.
    if not os.getenv("FIREWORKS_API_KEY"):
        print("β οΈ κ²½κ³ : FIREWORKS_API_KEY νκ²½λ³μκ° μ€μ λμ§ μμμ΅λλ€.")
        print("μ€μ λ°©λ²:")
        print("  Linux/Mac: export FIREWORKS_API_KEY='your-api-key'")
        print("  Windows: set FIREWORKS_API_KEY=your-api-key")
        print("")

        api_key = input("API ν€λ₯Ό μλ ₯νμΈμ (Enterλ₯Ό λλ₯΄λ©΄ 건λλλλ€): ").strip()
        if api_key:
            os.environ["FIREWORKS_API_KEY"] = api_key
            fireworks_client = FireworksClient(api_key)
            gradio_app = create_gradio_app(fireworks_client)
            app = gr.mount_gradio_app(app, gradio_app, path="/gradio")

    print("\nπ μλ²λ₯Ό μμν©λλ€...")
    print("π API λ¬Έμ: http://localhost:7860/docs")
    print("π¬ Gradio UI: http://localhost:7860/gradio")
    print("π§ API μλν¬μΈνΈ: http://localhost:7860/chat")

    # Serve FastAPI + mounted Gradio on a single port.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        reload=False
    )