Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| from typing import Optional | |
| from cold.classifier import ToxicTextClassifier | |
| import torch | |
app = FastAPI()

# Cross-origin policy: accept requests from any origin, with any HTTP method
# and any request header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the FastAPI CORS docs; confirm credentials are really needed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Build the classifier and load its trained weights onto the CPU.
model = ToxicTextClassifier()
model.load_state_dict(torch.load("output/lited_best.pth", map_location="cpu"))
# Switch to inference mode so layers such as dropout/batch-norm behave
# deterministically when serving predictions. This is a no-op if
# ToxicTextClassifier.predict already calls eval() itself.
model.eval()
class PredictionInput(BaseModel):
    # Request body for the classifier: text to classify plus optional context.

    # Required field (an Ellipsis default marks it as required).
    text: str = Field(
        default=...,
        title="Text to classify",
        description="The text to classify for malicious content",
    )
    # Optional field; None when the client omits it.
    context: Optional[str] = Field(
        default=None,
        title="Context for classification",
        description="Optional context to provide additional information for classification",
    )
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# function, so as shown it is never registered with ``app``. Confirm the
# decorator exists in the deployed file or was lost in extraction.
def predict(input: PredictionInput):
    """Classify ``input.text``, optionally paired with ``input.context``.

    Args:
        input: Request body carrying the text to classify and optional context.

    Returns:
        A dict echoing ``text`` (and ``context`` when a non-empty one was
        given), plus ``prediction`` and ``probabilities`` taken from the first
        element of the model's result.

    Raises:
        HTTPException: status 400 when ``text`` is empty or when ``text`` or
            ``context`` exceeds 512 characters; status 500 when the model call
            or result extraction fails.
    """
    max_chars = 512

    # Validate before entering the try block. Otherwise the generic handler
    # below would catch these HTTPExceptions and turn the 400s into 500s.
    if not input.text:
        raise HTTPException(status_code=400, detail="Text input is required")
    if len(input.text) > max_chars:
        raise HTTPException(
            status_code=400,
            detail=f"Text input exceeds maximum length of {max_chars} characters",
        )
    if input.context and len(input.context) > max_chars:
        raise HTTPException(
            status_code=400,
            detail=f"Context input exceeds maximum length of {max_chars} characters",
        )

    try:
        if input.context:
            result = model.predict([[input.text, input.context]], device="cpu")
        else:
            # NOTE(review): this branch passes a bare string while the context
            # branch passes a list of [text, context] pairs; confirm
            # model.predict accepts both forms.
            result = model.predict(input.text, device="cpu")
        first = result[0]
        # Key order matches the original responses: text, [context],
        # prediction, probabilities.
        response = {"text": input.text}
        if input.context:
            response["context"] = input.context
        response["prediction"] = first["prediction"]
        response["probabilities"] = first["probabilities"]
    except Exception as e:
        # NOTE(review): str(e) exposes internal error details to clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return response
# Serve the files in ./out at the site root. html=True serves index.html for
# directory requests. A mount at "/" matches every path, so it must stay the
# last route registered on the app.
_static_site = StaticFiles(directory="out", html=True)
app.mount("/", _static_site, name="static")