from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM
import os
import shutil

# Set HF environment variables BEFORE importing huggingface_hub: the hub
# reads HF_HOME and HF_HUB_ENABLE_HF_TRANSFER at import time, so setting
# them after the import has no effect.
os.environ["HF_HOME"] = "/app/cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/app/cache/huggingface/hub"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import hf_hub_download  # noqa: E402 (import after env setup is intentional)

app = FastAPI(title="GPT-OSS-20B API")

# Model ID and local directory
MODEL_ID = "unsloth/gpt-oss-20b-GGUF"
MODEL_DIR = "/app/gpt-oss-20b"
MODEL_FILE = "gpt-oss-20b.Q4_K_M.gguf"  # Adjust based on actual filename
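
# The exact GGUF filename varies by quantization. One way to list the files
# available in the repo is huggingface_hub's list_repo_files (a quick sketch,
# assuming network access; run it once interactively rather than on every
# startup):
#
#   from huggingface_hub import list_repo_files
#   print([f for f in list_repo_files(MODEL_ID) if f.endswith(".gguf")])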

# Clear the cache directory so stale or partial downloads don't accumulate
cache_dir = os.environ["HF_HOME"]
if os.path.exists(cache_dir):
    print(f"Clearing cache directory: {cache_dir}")
    for item in os.listdir(cache_dir):
        item_path = os.path.join(cache_dir, item)
        if os.path.isdir(item_path):
            shutil.rmtree(item_path, ignore_errors=True)
        elif os.path.exists(item_path):
            os.remove(item_path)

# Create directories
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the GGUF file from the Hub (resolved through cache_dir, saved to MODEL_DIR)
print("Downloading model file...")
try:
    hf_hub_download(
        repo_id=MODEL_ID,
        filename=MODEL_FILE,
        local_dir=MODEL_DIR,
        cache_dir=cache_dir,
    )
    print("Model file downloaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to download model: {e}") from e

# Load the model with ctransformers
print("Loading model...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        model_type="llama",  # Adjust: ctransformers expects an architecture name it supports, not the "gguf" file format
        model_file=MODEL_FILE,
    )
    print("Model loaded successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}") from e

class ChatRequest(BaseModel):
    message: str
    max_tokens: int = 256
    temperature: float = 0.7

@app.post("/chat")
def chat_endpoint(request: ChatRequest):
    # Declared sync (not async) so FastAPI runs the blocking ctransformers
    # generation call in its threadpool instead of stalling the event loop.
    try:
        response = model(
            request.message,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
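
# Example request (a quick sketch; host/port match the uvicorn.run call above):
#
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello, who are you?", "max_tokens": 64, "temperature": 0.7}'
#
# The endpoint returns JSON of the form {"response": "<generated text>"}.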