"""Unified Cervical & Breast Cancer Analysis API.

One FastAPI app fronting four models: YOLO colposcopy detection, the MWT
classifier, a CIN classifier (ResNet features + logistic regression) and a
breast-cancer histopathology classifier (TensorFlow).
"""

import os
import shutil

# Writable cache/config locations for the libraries imported further down.
# Kept in a single mapping so the directories we wipe and the directories we
# point the libraries at cannot drift apart.
_CACHE_DIRS = {
    "HF_HOME": "/tmp/huggingface",
    "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface",
    "TORCH_HOME": "/tmp/torch",
    "MPLCONFIGDIR": "/tmp/matplotlib",
    "YOLO_CONFIG_DIR": "/tmp/Ultralytics",
}

# Start every boot from empty caches (also clearing /root/.cache).
# Best-effort: missing or undeletable directories are silently ignored.
# dict.fromkeys de-duplicates while keeping a deterministic order.
for _cache_dir in dict.fromkeys([*_CACHE_DIRS.values(), "/root/.cache"]):
    shutil.rmtree(_cache_dir, ignore_errors=True)

# Set before the library imports below: several of these libraries read their
# cache/config locations when they are first imported.
os.environ.update(_CACHE_DIRS)

from huggingface_hub import login

# Optional: authenticate to the Hugging Face Hub when a token is provided.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Import order is preserved from the original file (torch via ultralytics
# before tensorflow) to avoid changing native-library load order.
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from ultralytics import YOLO
from io import BytesIO
from PIL import Image
import uvicorn
import json
import uuid
import numpy as np
import torch
import cv2
import joblib
import io  # NOTE(review): not referenced anywhere in this file
import tensorflow as tf
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.preprocessing import MinMaxScaler

from model import MWT as create_model
from augmentations import Augmentations
from model_histo import BreastCancerClassifier  # TensorFlow model

# =====================================================
# App setup
# =====================================================
app = FastAPI(title="Unified Cervical & Breast Cancer Analysis API")

# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette
# echoes the caller's Origin on credentialed requests instead of sending "*".
# Nothing in this file reads cookies or auth headers; revisit before adding any.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Annotated images written by the YOLO endpoint are served from /outputs.
OUTPUT_DIR = "/tmp/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================================
# Model 1: YOLO (Colposcopy Detection)
# =====================================================
print("🔹 Loading YOLO model...")
yolo_model = YOLO("best2.pt")
# =====================================================
# Model 2: MWT Classifier
# =====================================================
print("🔹 Loading MWT model...")
mwt_model = create_model(num_classes=2).to(device)
# The checkpoint is a state dict loaded from a local, trusted file.
mwt_model.load_state_dict(torch.load("MWTclass2.pth", map_location=device))
mwt_model.eval()
mwt_class_names = ['Negative', 'Positive']

# =====================================================
# Model 3: CIN Classifier
# =====================================================
print("🔹 Loading CIN model...")
clf = joblib.load("logistic_regression_model.pkl")
yolo_colposcopy = YOLO("yolo_colposcopy.pt")

# The logistic regression needs features scaled exactly as they were during
# training. Fitting a fresh MinMaxScaler on one request's single sample maps
# every feature to 0, which makes the CIN prediction independent of the image.
# When the training-time scaler has been saved with joblib, point
# CIN_SCALER_PATH at it (or drop it next to the app under the default name).
CIN_SCALER_PATH = os.getenv("CIN_SCALER_PATH", "cin_minmax_scaler.pkl")
cin_scaler = joblib.load(CIN_SCALER_PATH) if os.path.exists(CIN_SCALER_PATH) else None
if cin_scaler is None:
    print(
        f"⚠️ No fitted CIN scaler found at {CIN_SCALER_PATH}; "
        "CIN features will be scaled per request (all zeros)."
    )

_RESNET_BUILDERS = {
    "resnet50": lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
    "resnet101": lambda: models.resnet101(weights=models.ResNet101_Weights.DEFAULT),
    "resnet152": lambda: models.resnet152(weights=models.ResNet152_Weights.DEFAULT),
}


def build_resnet(model_name="resnet50"):
    """Load a pretrained torchvision ResNet and split it into stages.

    Args:
        model_name: One of "resnet50", "resnet101" or "resnet152".

    Returns:
        A 5-tuple ``(stem, layer1, layer2, layer3, layer4)``. ``stem`` is
        conv1 -> bn1 -> relu -> maxpool. All modules are in eval mode on
        ``device``.

    Raises:
        ValueError: If ``model_name`` is not a supported variant.
    """
    try:
        builder = _RESNET_BUILDERS[model_name]
    except KeyError:
        raise ValueError(f"Unsupported ResNet variant: {model_name!r}") from None
    model = builder()
    model.eval().to(device)
    return (
        nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool),
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
    )


gap = nn.AdaptiveAvgPool2d((1, 1))
gmp = nn.AdaptiveMaxPool2d((1, 1))

resnet50_blocks = build_resnet("resnet50")
resnet101_blocks = build_resnet("resnet101")
resnet152_blocks = build_resnet("resnet152")

# NOTE(review): the CIN branch feeds an OpenCV (BGR) crop into this transform,
# while the Normalize statistics are the usual RGB ImageNet values. Left as-is
# because the logistic regression was presumably trained on features produced
# the same way — confirm against the training pipeline.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# =====================================================
# Model 4: Histopathology Classifier (TensorFlow)
# =====================================================
print("🔹 Loading Breast Cancer Histopathology model...")
classifier = BreastCancerClassifier(fine_tune=False)
if not classifier.authenticate_huggingface():
    raise RuntimeError("HuggingFace authentication failed.")
if not classifier.load_path_foundation():
    raise RuntimeError("Failed to load Path Foundation model.")
model_path = "histopathology_trained_model.keras"
classifier.model = tf.keras.models.load_model(model_path)
print(f"✅ Loaded model from {model_path}")


# =====================================================
# Helper functions
# =====================================================
def preprocess_for_mwt(image_np):
    """Turn an HxWx3 image array into a 1x3x224x224 float tensor for the MWT model.

    The array is resized to 224x224, passed through
    ``Augmentations.Normalization((0, 1))`` (presumably a rescale to [0, 1];
    confirm), cast to float32, channel-swapped with COLOR_BGR2RGB, and
    transposed to CHW with a batch axis.

    NOTE(review): the BGR->RGB swap implies a BGR (OpenCV) input, but
    ``predict`` passes the RGB array decoded by Pillow, so the model receives
    BGR. Verify which channel order the MWT checkpoint was trained on.
    """
    img = cv2.resize(image_np, (224, 224))
    img = Augmentations.Normalization((0, 1))(img)
    img = np.array(img, np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, axis=0)
    return torch.Tensor(img)


def extract_cbf_features(blocks, img_t):
    """Compute a pooled multi-stage feature vector from ResNet stages.

    Args:
        blocks: The 5-tuple returned by ``build_resnet``.
        img_t: Input tensor already on ``device``. Because ``view(-1)``
            flattens batch and channels together, a batch of one is assumed
            (this is how ``predict`` calls it).

    Returns:
        1-D numpy array: global-max-pooled outputs of the first two stages
        concatenated with global-avg-pooled outputs of the last three.
    """
    block1, block2, block3, block4, block5 = blocks
    with torch.no_grad():
        f1 = block1(img_t)
        f2 = block2(f1)
        f3 = block3(f2)
        f4 = block4(f3)
        f5 = block5(f4)
        p1 = gmp(f1).view(-1)
        p2 = gmp(f2).view(-1)
        p3 = gap(f3).view(-1)
        p4 = gap(f4).view(-1)
        p5 = gap(f5).view(-1)
        cbf_feature = torch.cat([p1, p2, p3, p4, p5], dim=0)
    return cbf_feature.cpu().numpy()


def predict_histopathology(image: Image.Image):
    """Classify a histopathology image as Benign or Malignant.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1],
    embedded with ``classifier.extract_embeddings``, and classified with
    the loaded Keras model.

    Returns:
        Dict with the model name, predicted label, max probability, and
        per-class probabilities.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((224, 224))
    img_array = np.expand_dims(np.array(image).astype("float32") / 255.0, axis=0)
    embeddings = classifier.extract_embeddings(img_array)
    prediction_proba = classifier.model.predict(embeddings, verbose=0)[0]
    predicted_class = int(np.argmax(prediction_proba))
    class_names = ["Benign", "Malignant"]
    return {
        "model_used": "Breast Cancer Histopathology Classifier",
        "prediction": class_names[predicted_class],
        "confidence": float(np.max(prediction_proba)),
        "probabilities": {
            "Benign": float(prediction_proba[0]),
            "Malignant": float(prediction_proba[1]),
        },
    }


def _predict_yolo(image):
    """Run colposcopy detection and save an annotated copy under OUTPUT_DIR."""
    results = yolo_model(image)
    detections = json.loads(results[0].to_json())
    output_filename = f"detected_{uuid.uuid4().hex[:8]}.jpg"
    output_path = os.path.join(OUTPUT_DIR, output_filename)
    results[0].save(filename=output_path)
    return {
        "model_used": "YOLO Detection",
        "detections": detections,
        "annotated_image_url": f"/outputs/{output_filename}",
    }


def _predict_mwt(image_np):
    """Run the binary MWT classifier on an image array."""
    tensor = preprocess_for_mwt(image_np)
    with torch.no_grad():
        output = mwt_model(tensor.to(device)).cpu()
    probs = torch.softmax(output, dim=1)[0]
    confidences = {name: float(probs[i]) for i, name in enumerate(mwt_class_names)}
    predicted_label = mwt_class_names[int(torch.argmax(probs))]
    return {"model_used": "MWT Classifier", "prediction": predicted_label, "confidence": confidences}


def _predict_cin(contents, image_np):
    """Detect the cervix, crop it, and grade it CIN1/CIN2/CIN3.

    Args:
        contents: Raw uploaded bytes (decoded with OpenCV, giving BGR).
        image_np: The same image decoded by Pillow as RGB. Used only if
            OpenCV cannot decode the bytes.
    """
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # OpenCV does not decode every format Pillow accepts; reuse Pillow's
        # result converted to OpenCV's BGR channel order.
        img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

    results = yolo_colposcopy.predict(source=img, conf=0.7, save=False, verbose=False)
    if len(results[0].boxes) == 0:
        return {"error": "No cervix detected"}

    height, width = img.shape[:2]
    x1, y1, x2, y2 = map(int, results[0].boxes.xyxy[0].cpu().numpy())
    # Clamp to the image so negative coordinates cannot wrap around in slicing.
    x1, x2 = max(x1, 0), min(x2, width)
    y1, y2 = max(y1, 0), min(y2, height)
    crop = img[y1:y2, x1:x2]
    if crop.size == 0:  # degenerate box; cv2.resize would raise on it
        return {"error": "No cervix detected"}

    crop = cv2.resize(crop, (224, 224))
    img_t = transform(crop).unsqueeze(0).to(device)

    f50 = extract_cbf_features(resnet50_blocks, img_t)
    f101 = extract_cbf_features(resnet101_blocks, img_t)
    f152 = extract_cbf_features(resnet152_blocks, img_t)
    features = np.concatenate([f50, f101, f152]).reshape(1, -1)

    if cin_scaler is not None:
        X_scaled = cin_scaler.transform(features)
    else:
        # Legacy behaviour: fitting on this single sample maps every feature
        # to 0, so the prediction does not depend on the image.
        X_scaled = MinMaxScaler().fit_transform(features)

    pred = clf.predict(X_scaled)[0]
    proba = clf.predict_proba(X_scaled)[0]
    # NOTE(review): assumes clf.classes_ == [0, 1, 2] in this order — confirm.
    classes = ["CIN1", "CIN2", "CIN3"]
    return {
        "model_used": "CIN Classifier",
        "prediction": classes[pred],
        "probabilities": dict(zip(classes, map(float, proba))),
    }


# =====================================================
# Main endpoint
# =====================================================
@app.post("/predict/")
async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
    """Run the requested model on an uploaded image.

    Args:
        model_name: "yolo", "mwt", "cin" or "histopathology".
        file: The uploaded image.

    Returns:
        The selected model's result dict. Returns a 400 JSON response if the
        upload is not a decodable image or ``model_name`` is unknown.
    """
    contents = await file.read()
    try:
        # UnidentifiedImageError and truncated-image errors subclass OSError.
        image = Image.open(BytesIO(contents)).convert("RGB")
    except OSError:
        return JSONResponse(
            content={"error": "Could not decode the uploaded file as an image"},
            status_code=400,
        )
    image_np = np.array(image)

    if model_name == "yolo":
        return _predict_yolo(image)
    if model_name == "mwt":
        return _predict_mwt(image_np)
    if model_name == "cin":
        return _predict_cin(contents, image_np)
    if model_name == "histopathology":
        return predict_histopathology(image)
    return JSONResponse(content={"error": "Invalid model name"}, status_code=400)


@app.get("/models")
def get_models():
    """List the model names accepted by /predict/."""
    return {"available_models": ["yolo", "mwt", "cin", "histopathology"]}


@app.get("/health")
def health():
    """Liveness probe."""
    return {"message": "Unified Cervical & Breast Cancer API is running!"}


# =====================================================
# Frontend
# =====================================================
# "/outputs" is mounted during app setup; mounting it a second time here was
# redundant (the first mount always matched first).
app.mount("/assets", StaticFiles(directory="frontend/dist/assets"), name="assets")


@app.get("/")
async def serve_frontend():
    """Serve the SPA entry point.

    Registered before the catch-all "/" mount below. Starlette matches routes
    in registration order, so a route added after that mount would never run.
    """
    from fastapi.responses import FileResponse

    index_path = os.path.join("frontend", "dist", "index.html")
    return FileResponse(index_path)


# Catch-all for the built frontend; must stay last so it does not shadow the
# API routes registered above.
app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="static")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)