Pathora / backend /app.py
malavikapradeep2001's picture
Update backend/app.py
6893c2b unverified
raw
history blame
9.36 kB
import os
import shutil
# Start every boot from clean caches, then redirect each library's cache /
# config directory into /tmp (presumably the writable location in the
# deployment container -- confirm). The variables are assigned before the
# third-party imports that follow, which may read them when first imported.
_STALE_CACHE_DIRS = (
    "/tmp/huggingface",
    "/tmp/Ultralytics",
    "/tmp/matplotlib",
    "/tmp/torch",
    "/root/.cache",
)
for stale_dir in _STALE_CACHE_DIRS:
    # ignore_errors=True: a directory that does not exist yet is not an error.
    shutil.rmtree(stale_dir, ignore_errors=True)
os.environ.update(
    HF_HOME="/tmp/huggingface",
    HUGGINGFACE_HUB_CACHE="/tmp/huggingface",
    TORCH_HOME="/tmp/torch",
    MPLCONFIGDIR="/tmp/matplotlib",
    YOLO_CONFIG_DIR="/tmp/Ultralytics",
)

from huggingface_hub import login

# Log in to the Hugging Face Hub only when a token is supplied via HF_TOKEN.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from ultralytics import YOLO
from io import BytesIO
from PIL import Image
import uvicorn
import json, os, uuid, numpy as np, torch, cv2, joblib, io, tensorflow as tf
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.preprocessing import MinMaxScaler
from model import MWT as create_model
from augmentations import Augmentations
from model_histo import BreastCancerClassifier # TensorFlow model
# =====================================================
# App setup
# =====================================================
app = FastAPI(title="Unified Cervical & Breast Cancer Analysis API")

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployed frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Inference runs on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Annotated images written by the YOLO branch of /predict/ are served from here.
# The directory is created up front so the static mount has something to serve.
OUTPUT_DIR = "/tmp/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
# =====================================================
# Model 1: YOLO (Colposcopy Detection)
# =====================================================
# Log line restored from mis-encoded text ("๐Ÿ”น" was the UTF-8 bytes of 🔹
# decoded with a legacy code page).
print("🔹 Loading YOLO model...")
# Ultralytics detector used by the "yolo" branch of /predict/.
yolo_model = YOLO("best2.pt")
# =====================================================
# Model 2: MWT Classifier
# =====================================================
print("🔹 Loading MWT model...")
# Two-class network; the checkpoint is a state_dict mapped onto the selected device.
mwt_model = create_model(num_classes=2).to(device)
mwt_model.load_state_dict(torch.load("MWTclass2.pth", map_location=device))
mwt_model.eval()  # eval mode: no dropout / batch-norm statistic updates at inference
# Output logit index i corresponds to mwt_class_names[i].
mwt_class_names = ['Negative', 'Positive']
# =====================================================
# Model 3: CIN Classifier
# =====================================================
print("🔹 Loading CIN model...")
# Classifier applied to the concatenated ResNet CBF features in the "cin" branch.
# joblib.load unpickles the file, so only ever point it at a trusted artifact.
clf = joblib.load("logistic_regression_model.pkl")
# Detector used to locate (and crop) the region that the CIN features are computed on.
yolo_colposcopy = YOLO("yolo_colposcopy.pt")
def build_resnet(model_name="resnet50"):
    """Load a pretrained torchvision ResNet and split it into five sequential stages.

    Args:
        model_name: "resnet50", "resnet101" or "resnet152"; the model is created
            with torchvision's default pretrained weights for that variant.

    Returns:
        A 5-tuple ``(stem, layer1, layer2, layer3, layer4)`` where ``stem`` is
        conv1 -> bn1 -> relu -> maxpool. The network is put in eval mode and
        moved to ``device`` before being split.

    Raises:
        ValueError: if ``model_name`` is not one of the supported variants.
    """
    if model_name == "resnet50":
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    elif model_name == "resnet101":
        model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
    elif model_name == "resnet152":
        model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
    else:
        # Without this branch an unknown name crashed later with an
        # UnboundLocalError on `model`, which hid the real cause.
        raise ValueError(
            f"Unsupported ResNet variant {model_name!r}; "
            "expected 'resnet50', 'resnet101' or 'resnet152'"
        )
    model.eval().to(device)
    return (
        nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool),
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
    )
# Global pooling heads used by extract_cbf_features: max pooling for the early
# stages, average pooling for the later ones.
gmp = nn.AdaptiveMaxPool2d((1, 1))
gap = nn.AdaptiveAvgPool2d((1, 1))

# Stage-split backbones; their pooled features are concatenated for the CIN classifier.
resnet50_blocks, resnet101_blocks, resnet152_blocks = (
    build_resnet(variant) for variant in ("resnet50", "resnet101", "resnet152")
)

# Preprocessing for the cropped region: array -> PIL -> 224x224 -> tensor,
# normalized with the usual ImageNet per-channel mean / std.
# NOTE(review): the "cin" branch feeds a crop decoded by cv2.imdecode (BGR order)
# into this RGB-style normalization; confirm the training pipeline did the same.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# =====================================================
# Model 4: Histopathology Classifier (TensorFlow)
# =====================================================
print("🔹 Loading Breast Cancer Histopathology model...")
classifier = BreastCancerClassifier(fine_tune=False)
# Fail fast at startup if either setup step reports failure; presumably the
# Path Foundation model is what classifier.extract_embeddings relies on -- confirm.
if not classifier.authenticate_huggingface():
    raise RuntimeError("HuggingFace authentication failed.")
if not classifier.load_path_foundation():
    raise RuntimeError("Failed to load Path Foundation model.")
model_path = "histopathology_trained_model.keras"
# Trained Keras model; predict_histopathology calls classifier.model.predict on
# the embeddings produced by classifier.extract_embeddings.
classifier.model = tf.keras.models.load_model(model_path)
# Log line restored from mis-encoded text ("โœ…" was the UTF-8 bytes of ✅).
print(f"✅ Loaded model from {model_path}")
# =====================================================
# Helper functions
# =====================================================
def preprocess_for_mwt(image_np):
    """Convert an image array into the batched, channels-first tensor fed to the MWT model.

    Steps:
    1. Resize to 224x224.
    2. Apply ``Augmentations.Normalization((0, 1))``, presumably a [0, 1]
       rescale (confirm).
    3. Cast to float32 and apply ``cv2.COLOR_BGR2RGB``.
    4. Move channels first and prepend a batch axis.

    With an HxWx3 input the result is a 1x3x224x224 tensor, assuming the
    normalization keeps the shape.

    NOTE(review): /predict/ passes ``np.array`` of a PIL image already converted
    to RGB, so this BGR->RGB swap actually hands the model BGR data. Confirm the
    channel order the MWT weights were trained with before changing it.
    """
    resized = cv2.resize(image_np, (224, 224))
    scaled = np.array(Augmentations.Normalization((0, 1))(resized), np.float32)
    swapped = cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB)
    channels_first = swapped.transpose(2, 0, 1)
    batched = np.expand_dims(channels_first, axis=0)
    return torch.Tensor(batched)
def extract_cbf_features(blocks, img_t):
    """Pool the output of each of five backbone stages into one flat feature vector.

    ``blocks`` must unpack into exactly five callables (as returned by
    build_resnet); they are applied in sequence starting from ``img_t``.

    Pooling:
    - Stages 1-2 are reduced with the global max-pool head ``gmp``.
    - Stages 3-5 are reduced with the global average-pool head ``gap``.

    Each pooled map is flattened with ``view(-1)``. The pieces are concatenated
    in stage order and returned as a NumPy array on the CPU. Because ``view(-1)``
    also flattens the batch axis, this is meant for a batch of one image.
    """
    stem, stage2, stage3, stage4, stage5 = blocks
    poolers = (gmp, gmp, gap, gap, gap)
    pooled = []
    with torch.no_grad():
        activation = img_t
        for stage, pool in zip((stem, stage2, stage3, stage4, stage5), poolers):
            activation = stage(activation)
            pooled.append(pool(activation).view(-1))
        combined = torch.cat(pooled, dim=0)
    return combined.cpu().numpy()
def predict_histopathology(image: Image.Image):
    """Classify a histopathology image as "Benign" or "Malignant".

    The image is converted to RGB if needed, resized to 224x224 and scaled to
    [0, 1] float32 with a leading batch axis. It is then embedded with
    ``classifier.extract_embeddings`` and scored with ``classifier.model``.

    Returns a JSON-serialisable dict with:
    - the predicted label (argmax of the scores);
    - the top score as ``confidence``;
    - both per-class scores under ``probabilities``.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    pixels = np.asarray(rgb.resize((224, 224)), dtype="float32") / 255.0
    batch = pixels[np.newaxis, ...]

    embeddings = classifier.extract_embeddings(batch)
    scores = classifier.model.predict(embeddings, verbose=0)[0]

    labels = ["Benign", "Malignant"]
    best = int(np.argmax(scores))
    return {
        "model_used": "Breast Cancer Histopathology Classifier",
        "prediction": labels[best],
        "confidence": float(np.max(scores)),
        "probabilities": {
            "Benign": float(scores[0]),
            "Malignant": float(scores[1]),
        },
    }
# =====================================================
# Main endpoint
# =====================================================
# Optional path to the MinMaxScaler that was fitted on the CIN *training*
# features (e.g. saved with joblib.dump next to the logistic-regression model).
_CIN_SCALER_PATH = os.getenv("CIN_SCALER_PATH", "cin_minmax_scaler.pkl")
_cin_scaler = None           # fitted scaler, cached after the first successful load
_cin_scaler_warned = False   # print the "scaler missing" warning only once


def _scale_cin_features(features):
    """Scale a (1, n_features) CBF feature row for the CIN classifier.

    Returns ``(scaled_features, used_fallback)``.

    Fitting a MinMaxScaler on the single request sample (the old behaviour)
    gives every column a zero range. Every feature therefore becomes 0 and the
    classifier returns the same answer for every image. The correct scaler is
    the one fitted on the training data. It is loaded from ``CIN_SCALER_PATH``
    when that file exists.

    Otherwise the legacy per-request fit is kept so the endpoint still responds,
    but ``used_fallback`` is True so the caller can flag the result.
    """
    global _cin_scaler, _cin_scaler_warned
    if _cin_scaler is None and os.path.exists(_CIN_SCALER_PATH):
        # joblib.load unpickles: only point CIN_SCALER_PATH at a trusted file.
        _cin_scaler = joblib.load(_CIN_SCALER_PATH)
    if _cin_scaler is not None:
        return _cin_scaler.transform(features), False
    if not _cin_scaler_warned:
        print(
            f"WARNING: no fitted CIN scaler at {_CIN_SCALER_PATH!r}; "
            "CIN predictions will not depend on the input image."
        )
        _cin_scaler_warned = True
    return MinMaxScaler().fit_transform(features), True


def _predict_yolo(image):
    """Run the YOLO detector, save the annotated image under OUTPUT_DIR and return detections."""
    results = yolo_model(image)
    detections = json.loads(results[0].to_json())
    # NOTE(review): annotated files accumulate in OUTPUT_DIR; nothing here removes them.
    output_filename = f"detected_{uuid.uuid4().hex[:8]}.jpg"
    output_path = os.path.join(OUTPUT_DIR, output_filename)
    results[0].save(filename=output_path)
    return {
        "model_used": "YOLO Detection",
        "detections": detections,
        "annotated_image_url": f"/outputs/{output_filename}",
    }


def _predict_mwt(image_np):
    """Run the two-class MWT classifier and return the label plus per-class softmax scores."""
    tensor = preprocess_for_mwt(image_np)
    with torch.no_grad():
        output = mwt_model(tensor.to(device)).cpu()
    probs = torch.softmax(output, dim=1)[0]
    confidences = {name: float(p) for name, p in zip(mwt_class_names, probs)}
    predicted_label = mwt_class_names[int(torch.argmax(probs))]
    return {"model_used": "MWT Classifier", "prediction": predicted_label, "confidence": confidences}


def _predict_cin(contents):
    """Crop the detected region and classify it as CIN1 / CIN2 / CIN3.

    The raw upload bytes are decoded with OpenCV. The first box returned by
    ``yolo_colposcopy`` at conf=0.7 is cropped and resized to 224x224, and
    CBF features from the three ResNet backbones are concatenated. The
    classifier then runs on the scaled features.
    """
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode returns None for buffers it cannot decode, even if PIL could.
        return JSONResponse(
            content={"error": "Image format not supported for CIN analysis"},
            status_code=400,
        )

    results = yolo_colposcopy.predict(source=img, conf=0.7, save=False, verbose=False)
    if len(results[0].boxes) == 0:
        return {"error": "No cervix detected"}

    x1, y1, x2, y2 = map(int, results[0].boxes.xyxy[0].cpu().numpy())
    # Clamp to the frame: a negative index would silently wrap around in NumPy slicing.
    height, width = img.shape[:2]
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, width), min(y2, height)
    if x2 <= x1 or y2 <= y1:
        # Degenerate box: cv2.resize would fail on an empty crop.
        return {"error": "No cervix detected"}

    crop = cv2.resize(img[y1:y2, x1:x2], (224, 224))
    img_t = transform(crop).unsqueeze(0).to(device)
    features = np.concatenate([
        extract_cbf_features(resnet50_blocks, img_t),
        extract_cbf_features(resnet101_blocks, img_t),
        extract_cbf_features(resnet152_blocks, img_t),
    ]).reshape(1, -1)

    X_scaled, used_fallback = _scale_cin_features(features)
    pred = clf.predict(X_scaled)[0]
    proba = clf.predict_proba(X_scaled)[0]
    # NOTE(review): assumes clf.classes_ is [0, 1, 2] in this order -- verify
    # against the training labels.
    classes = ["CIN1", "CIN2", "CIN3"]
    response = {
        "model_used": "CIN Classifier",
        "prediction": classes[pred],
        "probabilities": dict(zip(classes, map(float, proba))),
    }
    if used_fallback:
        # Additive key: existing clients ignore it, new ones can surface it.
        response["warning"] = (
            "Training-time feature scaler not available; "
            "this prediction does not depend on the uploaded image."
        )
    return response


@app.post("/predict/")
async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
    """Run the model selected by ``model_name`` on the uploaded image.

    Supported values: "yolo", "mwt", "cin", "histopathology" (same list as
    GET /models). Unreadable images and unknown model names produce a 400
    JSON error.
    """
    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-images
        # and OSError for truncated files; report a client error, not a 500.
        return JSONResponse(content={"error": "Invalid or unreadable image file"}, status_code=400)

    if model_name == "yolo":
        return _predict_yolo(image)
    if model_name == "mwt":
        return _predict_mwt(np.array(image))
    if model_name == "cin":
        return _predict_cin(contents)
    if model_name == "histopathology":
        return predict_histopathology(image)
    return JSONResponse(content={"error": "Invalid model name"}, status_code=400)
@app.get("/models")
def get_models():
    """Return the model_name values accepted by POST /predict/."""
    supported = ["yolo", "mwt", "cin", "histopathology"]
    return {"available_models": supported}
@app.get("/health")
def health():
    """Report that the process is serving requests; returns a fixed message and checks nothing else."""
    status_message = "Unified Cervical & Breast Cancer API is running!"
    return {"message": status_message}
# =====================================================
# Frontend (built SPA) serving
# =====================================================
# "/outputs" is already mounted in the app-setup section; mounting it a second
# time only registered a duplicate route, so it is not repeated here.
from fastapi.responses import FileResponse

_FRONTEND_DIST = os.path.join("frontend", "dist")


# Registered BEFORE the catch-all "/" mount below: Starlette matches routes in
# registration order, so a route added after app.mount("/") is never reached.
@app.get("/")
async def serve_frontend():
    """Serve the built frontend's index.html at the site root."""
    index_path = os.path.join(_FRONTEND_DIST, "index.html")
    return FileResponse(index_path)


app.mount("/assets", StaticFiles(directory=os.path.join(_FRONTEND_DIST, "assets")), name="assets")
# Catch-all static mount; keep it last so every API route above takes precedence.
app.mount("/", StaticFiles(directory=_FRONTEND_DIST, html=True), name="static")
if __name__ == "__main__":
    # Listen on every interface on port 7860 when run directly as a script.
    uvicorn.run(app, port=7860, host="0.0.0.0")