# lite_DETECTIVE / app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from typing import Optional
from cold.classifier import ToxicTextClassifier
import torch
app = FastAPI()
# Allow cross-origin requests from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the classifier and its trained weights for CPU inference.
model = ToxicTextClassifier()
model.load_state_dict(torch.load("output/lited_best.pth", map_location="cpu"))
model.eval()  # switch to inference mode (disables dropout, etc.)
class PredictionInput(BaseModel):
    text: str = Field(..., title="Text to classify", description="The text to classify for malicious content")
    context: Optional[str] = Field(None, title="Context for classification", description="Optional context to provide additional information for classification")
@app.post("/predict")
def predict(input: PredictionInput):
    try:
        # Validate the inputs before running the model.
        if not input.text:
            raise HTTPException(status_code=400, detail="Text input is required")
        elif len(input.text) > 512:
            raise HTTPException(status_code=400, detail="Text input exceeds maximum length of 512 characters")
        if input.context and len(input.context) > 512:
            raise HTTPException(status_code=400, detail="Context input exceeds maximum length of 512 characters")
        if not input.context:
            # Classify the text on its own.
            result = model.predict(input.text, device="cpu")
            return {"text": input.text, "prediction": result[0]["prediction"], "probabilities": result[0]["probabilities"]}
        else:
            # Classify the text together with its context.
            result = model.predict([[input.text, input.context]], device="cpu")
            return {"text": input.text, "context": input.context, "prediction": result[0]["prediction"], "probabilities": result[0]["probabilities"]}
    except HTTPException:
        # Re-raise validation errors so they keep their 400 status instead of being turned into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Serve the exported static front-end from ./out at the site root.
app.mount("/", StaticFiles(directory="out", html=True), name="static")
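
# Example client call (a minimal sketch, assuming the app is served locally on
# port 7860, the Hugging Face Spaces default; the label names returned in
# "prediction" and "probabilities" depend on ToxicTextClassifier and are only
# illustrative here):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"text": "some text to screen", "context": "optional surrounding text"},
#   )
#   print(resp.json())
#   # -> {"text": "...", "context": "...", "prediction": "<label>", "probabilities": {...}}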