from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM
import os
import shutil
from huggingface_hub import hf_hub_download
app = FastAPI(title="GPT-OSS-20B API")
# Set environment variables
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/app/cache/huggingface/hub"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Model ID and local directory
MODEL_ID = "unsloth/gpt-oss-20b-GGUF"
MODEL_DIR = "/app/gpt-oss-20b"
MODEL_FILE = "gpt-oss-20b.Q4_K_M.gguf" # Adjust based on actual filename
# Clear cache directory
cache_dir = os.environ["HF_HOME"]
if os.path.exists(cache_dir):
    print(f"Clearing cache directory: {cache_dir}")
    for item in os.listdir(cache_dir):
        item_path = os.path.join(cache_dir, item)
        if os.path.isdir(item_path):
            shutil.rmtree(item_path, ignore_errors=True)
        elif os.path.exists(item_path):
            os.remove(item_path)
# Create directories
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
# Download model file
print("Downloading model file...")
try:
    hf_hub_download(
        repo_id=MODEL_ID,
        filename=MODEL_FILE,
        local_dir=MODEL_DIR,
        cache_dir=cache_dir,
    )
    print("Model file downloaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to download model: {str(e)}") from e
# Load model
print("Loading model...")
try:
    # model_type is omitted: ctransformers can usually read the architecture
    # from the GGUF metadata. If detection fails, pass it explicitly
    # (e.g. model_type="llama"); "gguf" is a file format, not a model type.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        model_file=MODEL_FILE,
    )
    print("Model loaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {str(e)}") from e
class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 256
    temperature: float = 0.7
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    try:
        # Generate response
        response = model(
            request.message,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
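
# Example request against the /chat endpoint (a minimal sketch; assumes the
# server is running locally on the default port 8000 used above):
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello!", "max_tokens": 64, "temperature": 0.7}'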