Spaces:
Running
Running
| import os | |
| import shutil | |
# Remove caches left behind by earlier runs, then point every library's
# cache/config directory at /tmp. This happens before huggingface_hub (and,
# further down, torch/matplotlib/ultralytics) are imported, presumably because
# those libraries read these variables when they are imported.
for cache_dir in (
    "/tmp/huggingface",
    "/tmp/Ultralytics",
    "/tmp/matplotlib",
    "/tmp/torch",
    "/root/.cache",
):
    shutil.rmtree(cache_dir, ignore_errors=True)

os.environ.update(
    {
        "HF_HOME": "/tmp/huggingface",
        "HUGGINGFACE_HUB_CACHE": "/tmp/huggingface",
        "TORCH_HOME": "/tmp/torch",
        "MPLCONFIGDIR": "/tmp/matplotlib",
        "YOLO_CONFIG_DIR": "/tmp/Ultralytics",
    }
)

from huggingface_hub import login

# Authenticate with the Hugging Face Hub only when HF_TOKEN is set.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from ultralytics import YOLO | |
| from io import BytesIO | |
| from PIL import Image | |
| import uvicorn | |
| import json, os, uuid, numpy as np, torch, cv2, joblib, io, tensorflow as tf | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from sklearn.preprocessing import MinMaxScaler | |
| from model import MWT as create_model | |
| from augmentations import Augmentations | |
| from model_histo import BreastCancerClassifier # TensorFlow model | |
# =====================================================
# App setup
# =====================================================
app = FastAPI(title="Unified Cervical & Breast Cancer Analysis API")

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; tighten this if the API ever relies on
# cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Annotated detection images are written here and served back under /outputs.
OUTPUT_DIR = "/tmp/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")

# Run the PyTorch models on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# =====================================================
# Model 1: YOLO (Colposcopy Detection)
# =====================================================
print("๐น Loading YOLO model...")
# Detector used by the "yolo" option of the prediction endpoint.
yolo_model = YOLO("best2.pt")

# =====================================================
# Model 2: MWT Classifier
# =====================================================
print("๐น Loading MWT model...")
# Two-class model; its weights are stored as a state_dict and mapped onto
# ``device`` while loading.
mwt_model = create_model(num_classes=2).to(device)
mwt_state_path = "MWTclass2.pth"
mwt_model.load_state_dict(torch.load(mwt_state_path, map_location=device))
# Inference only (affects layers such as dropout / batch-norm, if present).
mwt_model.eval()
# Output index i of the model corresponds to mwt_class_names[i].
mwt_class_names = ["Negative", "Positive"]
# =====================================================
# Model 3: CIN Classifier
# =====================================================
print("๐น Loading CIN model...")
# Cervix detector; its first box is cropped before feature extraction.
yolo_colposcopy = YOLO("yolo_colposcopy.pt")
# Classifier over concatenated ResNet features (presumably a scikit-learn
# logistic regression, per the filename). joblib.load unpickles the file, so
# only ever point this at a trusted artifact.
clf = joblib.load("logistic_regression_model.pkl")
def build_resnet(model_name="resnet50"):
    """Load a pretrained torchvision ResNet and split it into five feature stages.

    Args:
        model_name: One of ``"resnet50"``, ``"resnet101"`` or ``"resnet152"``.
            Each is loaded with torchvision's ``DEFAULT`` weights.

    Returns:
        A 5-tuple ``(stem, layer1, layer2, layer3, layer4)`` where ``stem`` is
        ``nn.Sequential(conv1, bn1, relu, maxpool)``. The underlying model is
        switched to eval mode and moved to ``device`` before it is split.

    Raises:
        ValueError: If ``model_name`` is not one of the supported variants.
    """
    # Lambdas defer construction so only the requested variant's pretrained
    # weights are loaded.
    factories = {
        "resnet50": lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
        "resnet101": lambda: models.resnet101(weights=models.ResNet101_Weights.DEFAULT),
        "resnet152": lambda: models.resnet152(weights=models.ResNet152_Weights.DEFAULT),
    }
    factory = factories.get(model_name)
    if factory is None:
        # Previously an unknown name fell through every branch and failed
        # later with an opaque UnboundLocalError on ``model``.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(factories)}"
        )
    model = factory()
    model.eval().to(device)
    return (
        nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool),
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
    )
# Pooling used by extract_cbf_features: max-pooling for the early stages,
# average-pooling for the later ones.
gmp = nn.AdaptiveMaxPool2d((1, 1))
gap = nn.AdaptiveAvgPool2d((1, 1))

# Stage blocks for each ResNet depth; the CIN pipeline concatenates the pooled
# features from all three.
resnet50_blocks, resnet101_blocks, resnet152_blocks = (
    build_resnet(depth_name) for depth_name in ("resnet50", "resnet101", "resnet152")
)

# HxWx3 uint8 array -> PIL -> 224x224 -> float tensor normalised with the
# ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# =====================================================
# Model 4: Histopathology Classifier (TensorFlow)
# =====================================================
print("๐น Loading Breast Cancer Histopathology model...")
classifier = BreastCancerClassifier(fine_tune=False)

# Both setup steps signal success through their return value. Run them in
# order and abort start-up at the first failure.
for setup_step, failure_message in (
    (classifier.authenticate_huggingface, "HuggingFace authentication failed."),
    (classifier.load_path_foundation, "Failed to load Path Foundation model."),
):
    if not setup_step():
        raise RuntimeError(failure_message)

# Trained Keras head that classifies the Path Foundation embeddings.
model_path = "histopathology_trained_model.keras"
classifier.model = tf.keras.models.load_model(model_path)
print(f"โ Loaded model from {model_path}")
# =====================================================
# Helper functions
# =====================================================
def preprocess_for_mwt(image_np):
    """Convert an image array into the batched float tensor the MWT model takes.

    Steps:
        1. Resize to 224x224.
        2. Rescale with ``Augmentations.Normalization((0, 1))``.
        3. Cast to float32.
        4. Swap channel order with ``cv2.COLOR_BGR2RGB``.
        5. Move channels first and add a batch axis.

    Args:
        image_np: HxWx3 image array (assumed BGR, given the BGR->RGB swap below).

    Returns:
        ``torch.Tensor`` of shape (1, 3, 224, 224).

    NOTE(review): ``predict`` passes ``np.array(pil_image)``, which is RGB, so
    the swap below actually feeds the model BGR data. Confirm which order the
    model was trained on before changing either side.
    """
    resized = cv2.resize(image_np, (224, 224))
    normalised = Augmentations.Normalization((0, 1))(resized)
    as_float = np.array(normalised, np.float32)
    swapped = cv2.cvtColor(as_float, cv2.COLOR_BGR2RGB)
    chw_batch = swapped.transpose(2, 0, 1)[np.newaxis, ...]
    return torch.Tensor(chw_batch)
def extract_cbf_features(blocks, img_t):
    """Run ``img_t`` through five sequential backbone stages and pool each output.

    Stages 1-2 are max-pooled (``gmp``) and stages 3-5 average-pooled (``gap``).
    Each pooled map is flattened and the five vectors are concatenated.

    Args:
        blocks: Exactly five callables applied in sequence, as returned by
            ``build_resnet``.
        img_t: Input tensor (a single image batch in this file).

    Returns:
        1-D NumPy array holding the concatenated pooled features.
    """
    stem, stage2, stage3, stage4, stage5 = blocks
    stage_outputs = []
    with torch.no_grad():
        activation = img_t
        for stage in (stem, stage2, stage3, stage4, stage5):
            activation = stage(activation)
            stage_outputs.append(activation)
    poolers = (gmp, gmp, gap, gap, gap)
    pooled = [pool(out).view(-1) for pool, out in zip(poolers, stage_outputs)]
    return torch.cat(pooled, dim=0).cpu().numpy()
def predict_histopathology(image: Image.Image):
    """Classify a histopathology image as Benign or Malignant.

    The image is converted to RGB if needed, resized to 224x224, scaled to
    [0, 1] and given a batch axis. It is then embedded with
    ``classifier.extract_embeddings`` and scored by ``classifier.model``.

    Returns:
        Dict with the model name, the top label, its probability, and the
        per-class probabilities (indices 0 and 1 of the model output).
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    pixels = np.array(rgb_image.resize((224, 224))).astype("float32") / 255.0
    batch = pixels[np.newaxis, ...]

    embeddings = classifier.extract_embeddings(batch)
    scores = classifier.model.predict(embeddings, verbose=0)[0]

    labels = ["Benign", "Malignant"]
    top_index = int(np.argmax(scores))
    return {
        "model_used": "Breast Cancer Histopathology Classifier",
        "prediction": labels[top_index],
        "confidence": float(np.max(scores)),
        "probabilities": {
            "Benign": float(scores[0]),
            "Malignant": float(scores[1]),
        },
    }
# =====================================================
# Main endpoint
# =====================================================
# Each _predict_* helper takes (raw upload bytes, decoded RGB PIL image) so
# that ``predict`` can dispatch through a single table.


def _predict_yolo(contents, image):
    """Run the YOLO detector and save an annotated copy of the image.

    Returns the detections (parsed from the Ultralytics JSON export) and a
    URL under /outputs for the annotated image.
    """
    results = yolo_model(image)
    detections = json.loads(results[0].to_json())
    # The random 8-hex-char suffix keeps per-request files distinct.
    # Collisions are unlikely but possible.
    output_filename = f"detected_{uuid.uuid4().hex[:8]}.jpg"
    results[0].save(filename=os.path.join(OUTPUT_DIR, output_filename))
    return {
        "model_used": "YOLO Detection",
        "detections": detections,
        "annotated_image_url": f"/outputs/{output_filename}",
    }


def _predict_mwt(contents, image):
    """Classify the image as Negative/Positive with the MWT model."""
    # NOTE(review): preprocess_for_mwt applies cv2.COLOR_BGR2RGB, but this
    # array comes from PIL and is already RGB, so the model receives BGR.
    # Confirm the training channel order before changing this.
    tensor = preprocess_for_mwt(np.array(image))
    with torch.no_grad():
        logits = mwt_model(tensor.to(device)).cpu()
    probs = torch.softmax(logits, dim=1)[0]
    confidences = {name: float(probs[i]) for i, name in enumerate(mwt_class_names)}
    predicted_label = mwt_class_names[int(torch.argmax(probs))]
    return {
        "model_used": "MWT Classifier",
        "prediction": predicted_label,
        "confidence": confidences,
    }


def _predict_cin(contents, image):
    """Grade CIN (CIN1-3) on the cervix region found by ``yolo_colposcopy``.

    Pipeline:
        1. Detect the cervix.
        2. Crop the first box and resize it to 224x224.
        3. Extract pooled ResNet-50/101/152 stage features.
        4. Classify the concatenated features with ``clf``.
    """
    # The CIN pipeline decodes the raw bytes with OpenCV (BGR order).
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # OpenCV cannot decode every format PIL accepts. Passing source=None on
        # to YOLO.predict would make Ultralytics fall back to its bundled
        # sample images instead of failing.
        return JSONResponse(
            content={"error": "Could not decode image for CIN model"},
            status_code=400,
        )

    results = yolo_colposcopy.predict(source=img, conf=0.7, save=False, verbose=False)
    boxes = results[0].boxes
    if len(boxes) == 0:
        return {"error": "No cervix detected"}

    # First box (presumably the highest-confidence one). Clamp to 0 because a
    # negative slice index would silently wrap to the other side of the image.
    x1, y1, x2, y2 = (max(0, int(v)) for v in boxes.xyxy[0].cpu().numpy())
    crop = img[y1:y2, x1:x2]
    if crop.size == 0:
        # A zero-area box would make cv2.resize raise, so treat it as no detection.
        return {"error": "No cervix detected"}
    crop = cv2.resize(crop, (224, 224))

    img_t = transform(crop).unsqueeze(0).to(device)
    features = np.concatenate(
        [
            extract_cbf_features(blocks, img_t)
            for blocks in (resnet50_blocks, resnet101_blocks, resnet152_blocks)
        ]
    ).reshape(1, -1)

    # NOTE(review): a MinMaxScaler fitted on this single row maps every
    # feature to 0 (each column's min equals its max). As a result, clf sees
    # the same all-zero vector for every image. The scaler fitted on the
    # training features should be persisted and loaded here instead. It is
    # left unchanged because that artifact is not available in this file.
    X_scaled = MinMaxScaler().fit_transform(features)
    pred = clf.predict(X_scaled)[0]
    proba = clf.predict_proba(X_scaled)[0]
    # Assumes clf was trained on integer labels 0, 1, 2 in this order.
    # TODO confirm against clf.classes_.
    classes = ["CIN1", "CIN2", "CIN3"]
    return {
        "model_used": "CIN Classifier",
        "prediction": classes[pred],
        "probabilities": dict(zip(classes, map(float, proba))),
    }


def _predict_histopathology(contents, image):
    """Delegate to predict_histopathology on the decoded RGB image."""
    return predict_histopathology(image)


# model_name form value -> handler.
_PREDICTORS = {
    "yolo": _predict_yolo,
    "mwt": _predict_mwt,
    "cin": _predict_cin,
    "histopathology": _predict_histopathology,
}


async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
    """Run the requested model on an uploaded image.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    this coroutine, so FastAPI would not expose it. Confirm that the deployed
    file registers it.
    NOTE(review): model inference is synchronous and runs on the event loop,
    so a slow request blocks all others.

    Args:
        model_name: Form field; one of "yolo", "mwt", "cin", "histopathology".
        file: Uploaded image.

    Returns:
        The selected model's result dict. Errors are reported as follows:
            - unknown model name: ``JSONResponse`` with status 400;
            - undecodable image: ``JSONResponse`` with status 400;
            - no cervix found (cin): ``{"error": ...}`` dict with status 200.
    """
    predictor = _PREDICTORS.get(model_name)
    if predictor is None:
        return JSONResponse(content={"error": "Invalid model name"}, status_code=400)

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError and truncated-file errors are OSError
        # subclasses. Previously they escaped as an HTTP 500.
        return JSONResponse(content={"error": "Invalid image file"}, status_code=400)

    return predictor(contents, image)
def get_models():
    """Return the identifiers accepted by the prediction endpoint's model_name field.

    NOTE(review): no route decorator is visible on this function.
    """
    model_names = ["yolo", "mwt", "cin", "histopathology"]
    return {"available_models": model_names}
def health():
    """Liveness probe returning a fixed status message.

    NOTE(review): no route decorator is visible on this function.
    """
    status_message = "Unified Cervical & Breast Cancer API is running!"
    return dict(message=status_message)
# Frontend static mounts. Starlette matches routes in registration order, so
# the catch-all "/" mount must be registered last. Anything registered after
# it is shadowed. "/outputs" is mounted during app setup, so it is not
# mounted again here: a second registration could never match.
app.mount("/assets", StaticFiles(directory="frontend/dist/assets"), name="assets")
# html=True makes StaticFiles serve index.html for directory requests.
app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="static")
async def serve_frontend():
    """Return the built frontend entry point, frontend/dist/index.html.

    NOTE(review): no route decorator is visible here. A route registered after
    the catch-all ``app.mount("/", ...)`` would be shadowed by it anyway, and
    that mount (html=True) already serves index.html.
    """
    # Imported locally: the module-level imports bring in only JSONResponse,
    # so a bare FileResponse reference raised NameError.
    from fastapi.responses import FileResponse

    index_path = os.path.join("frontend", "dist", "index.html")
    return FileResponse(index_path)
if __name__ == "__main__":
    # Serve on all interfaces, port 7860 (presumably the Hugging Face Spaces
    # app port this project is deployed on).
    uvicorn.run(app, port=7860, host="0.0.0.0")