from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import shutil
import json
from huggingface_hub import hf_hub_download
app = FastAPI(title="GPT-OSS-20B API")
# Set environment variables for Hugging Face cache
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/app/cache/huggingface/hub"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Model ID and local directory
MODEL_ID = "openai/gpt-oss-20b"
MODEL_DIR = "/app/gpt-oss-20b"
# Clear the cache directory so stale lock files from previous runs don't block downloads
cache_dir = os.environ["HF_HOME"]
if os.path.exists(cache_dir):
    print(f"Clearing cache directory: {cache_dir}")
    for item in os.listdir(cache_dir):
        item_path = os.path.join(cache_dir, item)
        if os.path.isdir(item_path):
            shutil.rmtree(item_path, ignore_errors=True)
        elif os.path.exists(item_path):
            os.remove(item_path)
# Create cache and model directories
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
# Download model files
print("Downloading model files...")
try:
    for file in ["config.json", "dtypes.json", "model.safetensors"]:
        hf_hub_download(
            repo_id=MODEL_ID,
            filename=f"original/{file}",
            local_dir=MODEL_DIR,
            cache_dir=cache_dir
        )
    print("Model files downloaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to download model files: {str(e)}")
# Fix config.json if model_type is missing or incorrect
config_path = os.path.join(MODEL_DIR, "original/config.json")
try:
    with open(config_path, "r") as f:
        config = json.load(f)
    if config.get("model_type") != "gpt_oss":
        print("Fixing config.json: setting model_type to 'gpt_oss'")
        config["model_type"] = "gpt_oss"
        with open(config_path, "w") as f:
            json.dump(config, f, indent=2)
except Exception as e:
    print(f"Warning: Failed to check or fix config.json: {str(e)}")
# Load tokenizer
print("Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,  # Load directly from the Hub
        cache_dir=cache_dir,
        trust_remote_code=True
    )
except Exception as e:
    raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
# Load model with CPU and disk offloading
print("Loading model (this may take several minutes)...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,  # Load directly from the Hub
        cache_dir=cache_dir,
        device_map="auto",  # Automatically place layers across GPU, CPU, and disk
        torch_dtype="auto",  # Use the dtype stored in the checkpoint
        offload_folder="/app/offload",  # Offload overflow weights to disk
        max_memory={0: "14GB", "cpu": "15GB"},  # Per-device memory limits (GPU 0 and CPU)
        trust_remote_code=True
    )
    print(f"Model loaded on: {model.device}")
    print(f"Model dtype: {model.dtype}")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {str(e)}")
# Gradient checkpointing only reduces memory during training; it has no effect on
# the no-grad generation below, but enabling it is harmless
model.gradient_checkpointing_enable()
class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 256
    temperature: float = 0.7
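# Example request body the endpoint accepts (illustrative values only):
#   {"message": "Hello, who are you?", "max_tokens": 128, "temperature": 0.7}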
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
try:
# Prepare input
messages = [{"role": "user", "content": request.message}]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True
).to("cpu")
        # Generate response
        with torch.no_grad():
            generated = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1
            )
        # Decode only the newly generated tokens
        response = tokenizer.decode(
            generated[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Free any GPU memory cached during model loading (no-op on CPU-only machines)
torch.cuda.empty_cache()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) |
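# Example call once the server is running (hypothetical prompt; adjust host/port to your deployment):
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Say hello in one sentence.", "max_tokens": 64}'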