Commit bf5da6b
Parent: 7f72e7d
Initial Space

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +3 -0
- Dockerfile +43 -0
- README.md +3 -4
- backend/MWTclass2.pth +3 -0
- backend/__pycache__/app0.cpython-312.pyc +0 -0
- backend/__pycache__/app2.cpython-312.pyc +0 -0
- backend/__pycache__/app3.cpython-312.pyc +0 -0
- backend/__pycache__/app4.cpython-312.pyc +0 -0
- backend/__pycache__/augmentations.cpython-312.pyc +0 -0
- backend/__pycache__/model.cpython-312.pyc +0 -0
- backend/__pycache__/model_histo.cpython-312.pyc +0 -0
- backend/app.py +241 -0
- backend/augmentations.py +328 -0
- backend/best2.pt +3 -0
- backend/histopathology_trained_model.keras +3 -0
- backend/logistic_regression_model.pkl +3 -0
- backend/model.py +521 -0
- backend/model_histo.py +1495 -0
- backend/requirements.txt +14 -0
- backend/yolo_colposcopy.pt +3 -0
- frontend/.eslintrc.cjs +18 -0
- frontend/.gitignore +24 -0
- frontend/README.md +8 -0
- frontend/index.html +13 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +35 -0
- frontend/postcss.config.js +6 -0
- frontend/public/banner.jpeg +0 -0
- frontend/public/black_logo.png +0 -0
- frontend/public/colpo/colp1.jpg +0 -0
- frontend/public/colpo/colp2.jpg +0 -0
- frontend/public/colpo/colp3.jpg +0 -0
- frontend/public/cyto/cyt1.jpg +3 -0
- frontend/public/cyto/cyt2.png +3 -0
- frontend/public/cyto/cyt3.png +3 -0
- frontend/public/histo/hist1.png +0 -0
- frontend/public/histo/hist2.png +0 -0
- frontend/public/histo/hist3.jpg +0 -0
- frontend/public/manalife_LOGO.jpg +0 -0
- frontend/public/white_logo.png +0 -0
- frontend/src/App.tsx +125 -0
- frontend/src/AppRouter.tsx +10 -0
- frontend/src/components/Footer.tsx +50 -0
- frontend/src/components/Header.tsx +40 -0
- frontend/src/components/ResultsPanel.tsx +143 -0
- frontend/src/components/Sidebar.tsx +54 -0
- frontend/src/components/UploadSection.tsx +194 -0
- frontend/src/components/progressbar.tsx +61 -0
- frontend/src/index.css +5 -0
- frontend/src/index.tsx +5 -0
.gitattributes
CHANGED
@@ -33,5 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+backend/histopathology_trained_model.keras filter=lfs diff=lfs merge=lfs -text
+frontend/public/cyto/*.jpg filter=lfs diff=lfs merge=lfs -text
+frontend/public/cyto/*.png filter=lfs diff=lfs merge=lfs -text
 backend/outputs/** filter=lfs diff=lfs merge=lfs -text
 frontend/public/cyto/** filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,43 @@
# -----------------------------
# 1️⃣ Build Frontend
# -----------------------------
FROM node:18-bullseye AS frontend-builder
WORKDIR /app/frontend
COPY frontend/package*.json ./
RUN npm install
COPY frontend/ .
RUN npm run build

# -----------------------------
# 2️⃣ Build Backend
# -----------------------------
FROM python:3.10-slim-bullseye

WORKDIR /app

# Install essential system libraries for OpenCV, YOLO, and Ultralytics
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    libgomp1 \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# Copy backend source code
COPY backend/ .

# Copy built frontend into the right folder for FastAPI
# ✅ this must match your app.mount() path in app.py
COPY --from=frontend-builder /app/frontend/dist ./frontend/dist

# Install Python dependencies
RUN pip install --upgrade pip
RUN pip install -r requirements.txt || pip install -r backend/requirements.txt || true

# Install runtime dependencies explicitly
RUN pip install --no-cache-dir fastapi uvicorn python-multipart ultralytics opencv-python-headless pillow numpy scikit-learn tensorflow keras

# Hugging Face Spaces expect port 7860
EXPOSE 7860

# Run FastAPI app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -1,11 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
+title: Proj Demo
+emoji: 🚀
+colorFrom: purple
 colorTo: green
 sdk: docker
 pinned: false
-short_description: Manalife's AI Pathology Assistant
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
backend/MWTclass2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09fae21e537056bc61f9ab4bb08249b591697bb087546cf639c599c70b8c6a2c
size 79777797
backend/__pycache__/app0.cpython-312.pyc
ADDED
Binary file (13 kB).

backend/__pycache__/app2.cpython-312.pyc
ADDED
Binary file (3.66 kB).

backend/__pycache__/app3.cpython-312.pyc
ADDED
Binary file (4 kB).

backend/__pycache__/app4.cpython-312.pyc
ADDED
Binary file (9.17 kB).

backend/__pycache__/augmentations.cpython-312.pyc
ADDED
Binary file (20.9 kB).

backend/__pycache__/model.cpython-312.pyc
ADDED
Binary file (31.3 kB).

backend/__pycache__/model_histo.cpython-312.pyc
ADDED
Binary file (69.2 kB).
backend/app.py
ADDED
@@ -0,0 +1,241 @@
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from ultralytics import YOLO
from io import BytesIO
from PIL import Image
import uvicorn
import json, os, uuid, numpy as np, torch, cv2, joblib, io, tensorflow as tf
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.preprocessing import MinMaxScaler
from model import MWT as create_model
from augmentations import Augmentations
from model_histo import BreastCancerClassifier  # TensorFlow model

from huggingface_hub import login
import os

hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)


# =====================================================
# App setup
# =====================================================
app = FastAPI(title="Unified Cervical & Breast Cancer Analysis API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================================
# Model 1: YOLO (Colposcopy Detection)
# =====================================================
print("🔹 Loading YOLO model...")
yolo_model = YOLO("best2.pt")

# =====================================================
# Model 2: MWT Classifier
# =====================================================
print("🔹 Loading MWT model...")
mwt_model = create_model(num_classes=2).to(device)
mwt_model.load_state_dict(torch.load("MWTclass2.pth", map_location=device))
mwt_model.eval()
mwt_class_names = ['neg', 'pos']

# =====================================================
# Model 3: CIN Classifier
# =====================================================
print("🔹 Loading CIN model...")
clf = joblib.load("logistic_regression_model.pkl")
yolo_colposcopy = YOLO("yolo_colposcopy.pt")

def build_resnet(model_name="resnet50"):
    if model_name == "resnet50":
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    elif model_name == "resnet101":
        model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
    elif model_name == "resnet152":
        model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
    model.eval().to(device)
    return (
        nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool),
        model.layer1, model.layer2, model.layer3, model.layer4,
    )

gap = nn.AdaptiveAvgPool2d((1, 1))
gmp = nn.AdaptiveMaxPool2d((1, 1))
resnet50_blocks = build_resnet("resnet50")
resnet101_blocks = build_resnet("resnet101")
resnet152_blocks = build_resnet("resnet152")

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# =====================================================
# Model 4: Histopathology Classifier (TensorFlow)
# =====================================================
print("🔹 Loading Breast Cancer Histopathology model...")
classifier = BreastCancerClassifier(fine_tune=False)
if not classifier.authenticate_huggingface():
    raise RuntimeError("HuggingFace authentication failed.")
if not classifier.load_path_foundation():
    raise RuntimeError("Failed to load Path Foundation model.")
model_path = "histopathology_trained_model.keras"
classifier.model = tf.keras.models.load_model(model_path)
print(f"✅ Loaded model from {model_path}")

# =====================================================
# Helper functions
# =====================================================
def preprocess_for_mwt(image_np):
    img = cv2.resize(image_np, (224, 224))
    img = Augmentations.Normalization((0, 1))(img)
    img = np.array(img, np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, axis=0)
    return torch.Tensor(img)

def extract_cbf_features(blocks, img_t):
    block1, block2, block3, block4, block5 = blocks
    with torch.no_grad():
        f1 = block1(img_t)
        f2 = block2(f1)
        f3 = block3(f2)
        f4 = block4(f3)
        f5 = block5(f4)
        p1 = gmp(f1).view(-1)
        p2 = gmp(f2).view(-1)
        p3 = gap(f3).view(-1)
        p4 = gap(f4).view(-1)
        p5 = gap(f5).view(-1)
    cbf_feature = torch.cat([p1, p2, p3, p4, p5], dim=0)
    return cbf_feature.cpu().numpy()

def predict_histopathology(image: Image.Image):
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize((224, 224))
    img_array = np.expand_dims(np.array(image).astype("float32") / 255.0, axis=0)
    embeddings = classifier.extract_embeddings(img_array)
    prediction_proba = classifier.model.predict(embeddings, verbose=0)[0]
    predicted_class = int(np.argmax(prediction_proba))
    class_names = ["Benign", "Malignant"]
    return {
        "model_used": "Breast Cancer Histopathology Classifier",
        "prediction": class_names[predicted_class],
        "confidence": float(np.max(prediction_proba)),
        "probabilities": {
            "Benign": float(prediction_proba[0]),
            "Malignant": float(prediction_proba[1])
        }
    }

# =====================================================
# Main endpoint
# =====================================================
@app.post("/predict/")
async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
    contents = await file.read()
    image = Image.open(BytesIO(contents)).convert("RGB")
    image_np = np.array(image)

    if model_name == "yolo":
        results = yolo_model(image)
        detections_json = results[0].to_json()
        detections = json.loads(detections_json)
        output_filename = f"detected_{uuid.uuid4().hex[:8]}.jpg"
        output_path = os.path.join(OUTPUT_DIR, output_filename)
        results[0].save(filename=output_path)
        return {
            "model_used": "YOLO Detection",
            "detections": detections,
            "annotated_image_url": f"/outputs/{output_filename}"
        }

    elif model_name == "mwt":
        tensor = preprocess_for_mwt(image_np)
        with torch.no_grad():
            output = mwt_model(tensor.to(device)).cpu()
            probs = torch.softmax(output, dim=1)[0]
        confidences = {mwt_class_names[i]: float(probs[i]) for i in range(2)}
        predicted_label = mwt_class_names[torch.argmax(probs)]
        return {"model_used": "MWT Classifier", "prediction": predicted_label, "confidence": confidences}

    elif model_name == "cin":
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        results = yolo_colposcopy.predict(source=img, conf=0.7, save=False, verbose=False)
        if len(results[0].boxes) == 0:
            return {"error": "No cervix detected"}
        x1, y1, x2, y2 = map(int, results[0].boxes.xyxy[0].cpu().numpy())
        crop = img[y1:y2, x1:x2]
        crop = cv2.resize(crop, (224, 224))
        img_t = transform(crop).unsqueeze(0).to(device)
        f50 = extract_cbf_features(resnet50_blocks, img_t)
        f101 = extract_cbf_features(resnet101_blocks, img_t)
        f152 = extract_cbf_features(resnet152_blocks, img_t)
        features = np.concatenate([f50, f101, f152]).reshape(1, -1)
        X_scaled = MinMaxScaler().fit_transform(features)
        pred = clf.predict(X_scaled)[0]
        proba = clf.predict_proba(X_scaled)[0]
        classes = ["CIN1", "CIN2", "CIN3"]
        return {
            "model_used": "CIN Classifier",
            "prediction": classes[pred],
            "probabilities": dict(zip(classes, map(float, proba)))
        }

    elif model_name == "histopathology":
        result = predict_histopathology(image)
        return result

    else:
        return JSONResponse(content={"error": "Invalid model name"}, status_code=400)


@app.get("/models")
def get_models():
    return {"available_models": ["yolo", "mwt", "cin", "histopathology"]}


@app.get("/health")
def health():
    return {"message": "Unified Cervical & Breast Cancer API is running!"}

# After other app.mount()s
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
app.mount("/assets", StaticFiles(directory="frontend/dist/assets"), name="assets")
from fastapi.staticfiles import StaticFiles

app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="static")


@app.get("/")
async def serve_frontend():
    index_path = os.path.join("frontend", "dist", "index.html")
    return FileResponse(index_path)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
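A minimal client sketch for the endpoint added above. This is not part of the commit; it only mirrors the /predict/, /models and /health routes defined in backend/app.py, and the base URL and sample file name are placeholders you would replace with the deployed Space URL and a real image.

import requests

BASE_URL = "http://localhost:7860"  # placeholder; point this at the running Space

def predict(image_path: str, model_name: str = "yolo"):
    """POST an image to /predict/ with model_name in {yolo, mwt, cin, histopathology}."""
    with open(image_path, "rb") as f:
        resp = requests.post(
            f"{BASE_URL}/predict/",
            data={"model_name": model_name},   # sent as a form field, matching Form(...)
            files={"file": f},                 # sent as multipart upload, matching File(...)
        )
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    print(requests.get(f"{BASE_URL}/health").json())
    print(requests.get(f"{BASE_URL}/models").json())
    print(predict("sample.jpg", model_name="histopathology"))  # hypothetical sample image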
backend/augmentations.py
ADDED
@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
This file is a part of project "Aided-Diagnosis-System-for-Cervical-Cancer-Screening".
See https://github.com/ShenghuaCheng/Aided-Diagnosis-System-for-Cervical-Cancer-Screening for more information.
File name: augmentations.py
Description: augmentation functions.
"""

import functools
import random
import cv2
import numpy as np
from skimage.exposure import adjust_gamma
from loguru import logger

__all__ = [
    "Augmentations",
    "StylisticTrans",
    "SpatialTrans",
]


class Augmentations:
    """
    All parameters in each augmentations have been fixed to a suitable range.
    img = [size, size, ch]
    ch = 3: only img
    ch = 4: img with mask at 4th dim
    """

    @staticmethod
    def Compose(*funcs):
        funcs = list(funcs)
        func_names = [f.__name__ for f in funcs]
        # ensure the norm opt is the last opt
        if 'norm' in func_names:
            idx = func_names.index('norm')
            funcs = funcs[:idx] + funcs[idx:] + [funcs[idx]]

        def compose(img: np.ndarray):
            return functools.reduce(lambda f, g: lambda x: g(f(x)), funcs)(img)

        return compose

    """
    # ===========================================================================================================
    # random stylistic augmentations
    """

    @staticmethod
    def RandomGamma(p: float = 0.5):
        def random_gamma(img: np.ndarray):
            if random.random() < p:
                gamma = 0.6 + random.random() * 0.6
                img[..., :3] = StylisticTrans.gamma_adjust(img[..., :3], gamma)
            return img

        return random_gamma

    @staticmethod
    def RandomSharp(p: float = 0.5):
        def random_sharp(img: np.ndarray):
            if random.random() < p:
                sigma = 8.3 + random.random() * 0.4
                img[..., :3] = StylisticTrans.sharp(img[..., :3], sigma)
            return img

        return random_sharp

    @staticmethod
    def RandomGaussainBlur(p: float = 0.5):
        def random_gaussian_blur(img: np.ndarray):
            if random.random() < p:
                sigma = 0.1 + random.random() * 1
                img[..., :3] = StylisticTrans.gaussian_blur(img[..., :3], sigma)
            return img

        return random_gaussian_blur

    @staticmethod
    def RandomHSVDisturb(p: float = 0.5):
        def random_hsv_disturb(img: np.ndarray):
            if random.random() < p:
                k = np.random.random(3) * [0.1, 0.8, 0.45] + [0.95, 0.7, 0.75]
                b = np.random.random(3) * [6, 20, 18] + [-3, -10, -10]
                img[..., :3] = StylisticTrans.hsv_disturb(img[..., :3], k.tolist(), b.tolist())
            return img

        return random_hsv_disturb

    @staticmethod
    def RandomRGBSwitch(p: float = 0.5):
        def random_rgb_switch(img: np.ndarray):
            if random.random() < p:
                bgr_seq = list(range(3))
                random.shuffle(bgr_seq)
                img[..., :3] = StylisticTrans.bgr_switch(img[..., :3], bgr_seq)
            return img

        return random_rgb_switch

    """
    # ===========================================================================================================
    # random spatial augmentations, funcs can be implement to tiles and their masks.
    """

    @staticmethod
    def RandomRotate90(p: float = 0.5):
        def random_rotate90(img: np.ndarray):
            if random.random() < p:
                angle = 90 * random.randint(1, 3)
                img = SpatialTrans.rotate(img, angle)
            return img

        return random_rotate90

    @staticmethod
    def RandomHorizontalFlip(p: float = 0.5):
        def random_horizontal_flip(img: np.ndarray):
            if random.random() < p:
                img = SpatialTrans.flip(img, 0)
            return img

        return random_horizontal_flip

    @staticmethod
    def RandomVerticalFlip(p: float = 0.5):
        def random_vertical_flip(img: np.ndarray):
            if random.random() < p:
                img = SpatialTrans.flip(img, 1)
            return img

        return random_vertical_flip

    @staticmethod
    def RandomScale(p: float = 0.5):
        def random_scale(img: np.ndarray):
            if random.random() < p:
                ratio = 0.8 + random.random() * 0.4
                img = SpatialTrans.scale(img, ratio, True)
            return img

        return random_scale

    @staticmethod
    def RandomCrop(p: float = 1., size: tuple = (512, 512)):
        def random_crop(img: np.ndarray):
            if random.random() < p:
                # for a large FOV, control the translate range
                new_shape = list(img.shape[:2][::-1])
                if img.shape[0] > size[1] * 1.5:
                    new_shape[1] = int(size[1] * 1.5)
                if img.shape[1] > size[0] * 1.5:
                    new_shape[0] = int(size[0] * 1.5)
                img = SpatialTrans.center_crop(img.copy(), tuple(new_shape))
                # do translate
                xy = np.random.random(2) * (np.array(img.shape[:2]) - list(size))
                bbox = tuple(xy.astype(np.int).tolist() + list(size))
                img = SpatialTrans.crop(img, bbox)
            else:
                img = SpatialTrans.center_crop(img, size)
            return img

        return random_crop

    @staticmethod
    def Normalization(rng: list = [-1, 1]):
        def norm(img: np.ndarray):
            img = StylisticTrans.normalization(img, rng)
            return img

        return norm

    @staticmethod
    def CenterCrop(size: tuple = (512, 512)):
        def center_crop(img: np.ndarray):
            img = SpatialTrans.center_crop(img, size)
            return img

        return center_crop


class StylisticTrans:
    # TODO Some implementations of augmentation need a efficient way
    """
    set of augmentations applied to the content of image
    """

    @staticmethod
    def gamma_adjust(img: np.ndarray, gamma: float):
        """ adjust gamma
        :param img: a ndarray, better a BGR
        :param gamma: gamma, recommended value 0.6, range [0.6, 1.2]
        :return: a ndarray
        """
        return adjust_gamma(img.copy(), gamma)

    @staticmethod
    def sharp(img: np.ndarray, sigma: float):
        """sharp image
        :param img: a ndarray, better a BGR
        :param sigma: sharp degree, recommended range [8.3, 8.7]
        :return: a ndarray
        """
        kernel = np.array([[-1, -1, -1], [-1, sigma, -1], [-1, -1, -1]], np.float32) / (sigma - 8)  # sharpen
        return cv2.filter2D(img.copy(), -1, kernel=kernel)

    @staticmethod
    def gaussian_blur(img: np.ndarray, sigma: float):
        """blurring image
        :param img: a ndarray, better a BGR
        :param sigma: blurring degree, recommended range [0.1, 1.1]
        :return: a ndarray
        """
        return cv2.GaussianBlur(img.copy(), (int(6 * np.ceil(sigma) + 1), int(6 * np.ceil(sigma) + 1)), sigma)

    @staticmethod
    def hsv_disturb(img: np.ndarray, k: list, b: list):
        """ disturb the hsv value
        :param img: a BGR ndarray
        :param k: low_b = [0.95, 0.7, 0.75] ,upper_b = [1.05, 1.5, 1.2]
        :param b: low_b = [-3, -10, -10] ,upper_b = [3, 10, 8]
        :return: a BGR ndarray
        """
        img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV)
        img = img.astype(np.float)
        for ch in range(3):
            img[..., ch] = k[ch] * img[..., ch] + b[ch]
        img = np.uint8(np.clip(img, np.array([0, 1, 1]), np.array([180, 255, 255])))
        return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)

    @staticmethod
    def bgr_switch(img: np.ndarray, bgr_seq: list):
        """ switch bgr
        :param img: a ndarray, better a BGR
        :param bgr_seq: new ch seq
        :return: a ndarray
        """
        return img.copy()[..., bgr_seq]

    @staticmethod
    def normalization(img: np.ndarray, rng: list):
        """normalize image according to min and max
        :param img: a ndarray
        :param rng: normalize image value to range[min, max]
        :return: a ndarray
        """
        lb, ub = rng
        delta = ub - lb
        return (img.copy().astype(np.float64) / 255.) * delta + lb  # yjx


class SpatialTrans:
    """
    set of augmentations applied to the spatial space of image
    """

    @staticmethod
    def rotate(img: np.ndarray, angle: int):
        """ rotate image
        # todo Square image and central rotate only, a universal version is needed
        :param img: a ndarray
        :param angle: rotate angle
        :return: a ndarray has same size as input, padding zero or cut out region out of picture
        """
        assert img.shape[0] == img.shape[1], "Square image needed."
        center = (img.shape[0]/2, img.shape[1]/2)
        mat = cv2.getRotationMatrix2D(center, angle, scale=1)
        # mat = cv2.getRotationMatrix2D(tuple(np.array(img.shape[:2]) // 2), angle, scale=1)
        return cv2.warpAffine(img.copy(), mat, img.shape[:2])

    @staticmethod
    def flip(img: np.ndarray, flip_axis: int):
        """flip image horizontal or vertical
        :param img: a ndarray
        :param flip_axis: 0 for horizontal, 1 for vertical
        :return: a flipped image
        """
        return cv2.flip(img.copy(), flip_axis)

    @staticmethod
    def scale(img: np.ndarray, ratio: float, fix_size: bool = False):
        """scale image
        :param img: a ndarray
        :param ratio: scale ratio
        :param fix_size: return the center area of scaled image, size of area is same as the image before scaling
        :return: a scaled image
        """
        shape = img.shape[:2][::-1]
        img = cv2.resize(img.copy(), None, fx=ratio, fy=ratio)
        if fix_size:
            img = SpatialTrans.center_crop(img, shape)
        return img

    @staticmethod
    def crop(img: np.ndarray, bbox: tuple):
        """crop image according to given bbox
        :param img: a ndarray
        :param bbox: bbox of cropping area (x, y, w, h)
        :return: cropped image, padding with zeros
        """
        ch = [] if len(img.shape) == 2 else [img.shape[-1]]
        template = np.zeros(list(bbox[-2:])[::-1] + ch)

        if (bbox[1] >= img.shape[0] or bbox[1] >= img.shape[1]) or (bbox[0] + bbox[2] <= 0 or bbox[1] + bbox[3] <= 0):
            logger.warning("Crop area contains nothing, return a zeros array {}".format(template.shape))
            return template

        foreground = img[
            np.maximum(bbox[1], 0): np.minimum(bbox[1] + bbox[3], img.shape[0]),
            np.maximum(bbox[0], 0): np.minimum(bbox[0] + bbox[2], img.shape[1]), :]

        template[
            np.maximum(-bbox[1], 0): np.minimum(-bbox[1] + img.shape[0], bbox[3]),
            np.maximum(-bbox[0], 0): np.minimum(-bbox[0] + img.shape[1], bbox[2]), :] = foreground
        return template.astype(np.uint8)

    @staticmethod
    def center_crop(img: np.ndarray, shape: tuple):
        """return the center area in shape
        :param img: a ndarray
        :param shape: center crop shape (w, h)
        :return:
        """
        center = np.array(img.shape[:2]) // 2
        init = center[::-1] - np.array(shape) // 2
        bbox = tuple(init.tolist() + list(shape))
        return SpatialTrans.crop(img, bbox)
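A minimal usage sketch, not part of the commit, assuming it is run from backend/ with the module's dependencies (opencv-python, scikit-image, loguru, numpy) installed. It shows that each Augmentations factory returns a closure over a probability, and mirrors how preprocess_for_mwt() in app.py applies Augmentations.Normalization((0, 1)); the random input array is a stand-in for a real tile.

import cv2
import numpy as np
from augmentations import Augmentations

image_bgr = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # stand-in tile

flip = Augmentations.RandomHorizontalFlip(p=1.0)   # p=1.0 so the flip always fires
normalize = Augmentations.Normalization((0, 1))    # returns the `norm` closure

img = cv2.resize(image_bgr, (224, 224))
img = flip(img)         # cv2.flip under the hood (SpatialTrans.flip)
img = normalize(img)    # float64 values rescaled from [0, 255] into [0, 1]
print(img.shape, float(img.min()), float(img.max()))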
backend/best2.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1dc58852bdf5287554846eacd4d27daf757dbbcb111cec21e7df2bc463401e5e
size 6236202
backend/histopathology_trained_model.keras
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62e427f5d52567b488b157fa1aa8e3ef8236434b1b3752ca49d8a46c144d90c1
size 8069694
backend/logistic_regression_model.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09c3442ef1cec5614ed17bbd9e1a4a1f8e38c588fdea13b2171b6662ace86c7b
size 94559
backend/model.py
ADDED
@@ -0,0 +1,521 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from typing import Optional
from thop import profile

class IRB(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, ksize=3, act_layer=nn.Hardswish, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0)
        self.act = act_layer()
        self.conv = nn.Conv2d(hidden_features, hidden_features, kernel_size=ksize, padding=ksize // 2, stride=1,
                              groups=hidden_features)
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.fc1(x)
        x = self.act(x)
        x = self.conv(x)
        x = self.act(x)
        x = self.fc2(x)
        return x.reshape(B, C, -1).permute(0, 2, 1)


def drop_path_f(x, drop_prob: float = 0., training: bool = False):

    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)


def window_partition(x, window_size: int):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size: int, H: int, W: int):

    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class PatchEmbed2(nn.Module):
    def __init__(self, dim: int, patch_size=2, in_c=3, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = dim
        self.proj = nn.Conv2d(dim, 2*dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(2*dim) if norm_layer else nn.Identity()

    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2)
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # if H or W of the input image is not an integer multiple of patch_size, pad it
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


class PatchMerging(nn.Module):

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]  # num_windows
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_sizes=(7, 4, 2), branch_num=3,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_sizes = window_sizes
        self.branch_num = branch_num
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim//branch_num, window_size=(self.window_sizes[0], self.window_sizes[0]), num_heads=num_heads//branch_num, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        self.attn1 = WindowAttention(
            dim//branch_num, window_size=(self.window_sizes[1], self.window_sizes[1]), num_heads=num_heads//branch_num, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        self.attn2 = WindowAttention(
            dim//branch_num, window_size=(self.window_sizes[2], self.window_sizes[2]), num_heads=num_heads//branch_num, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = IRB(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        x0 = x[:, :, :, :(C//self.branch_num)]
        x1 = x[:, :, :, (C//self.branch_num):(2*C//self.branch_num)]
        x2 = x[:, :, :, (2*C//self.branch_num):]
        # ----------------------------------------------------------------------------------------------
        pad_l = pad_t = 0
        pad_r = (self.window_sizes[0] - W % self.window_sizes[0]) % self.window_sizes[0]
        pad_b = (self.window_sizes[0] - H % self.window_sizes[0]) % self.window_sizes[0]
        x0 = F.pad(x0, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x0.shape
        attn_mask = None

        # partition windows
        x_windows = window_partition(x0, self.window_sizes[0])  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_sizes[0] * self.window_sizes[0], C//self.branch_num)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_sizes[0], self.window_sizes[0], C//self.branch_num)  # [nW*B, Mh, Mw, C]
        x0 = window_reverse(attn_windows, self.window_sizes[0], Hp, Wp)  # [B, H', W', C]

        if pad_r > 0 or pad_b > 0:
            # remove the padding added above
            x0 = x0[:, :H, :W, :].contiguous()

        x0 = x0.view(B, H * W, C//self.branch_num)
        # ----------------------------------------------------------------------------------------------
        pad_l = pad_t = 0
        pad_r = (self.window_sizes[1] - W % self.window_sizes[1]) % self.window_sizes[1]
        pad_b = (self.window_sizes[1] - H % self.window_sizes[1]) % self.window_sizes[1]
        x1 = F.pad(x1, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x1.shape
        attn_mask = None

        # partition windows
        x_windows = window_partition(x1, self.window_sizes[1])  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_sizes[1] * self.window_sizes[1],
                                   C // self.branch_num)  # [nW*B, Mh*Mw, C]

        # W-MSA/SW-MSA
        attn_windows = self.attn1(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_sizes[1], self.window_sizes[1],
                                         C // self.branch_num)  # [nW*B, Mh, Mw, C]
        x1 = window_reverse(attn_windows, self.window_sizes[1], Hp, Wp)  # [B, H', W', C]

        if pad_r > 0 or pad_b > 0:
            # remove the padding added above
            x1 = x1[:, :H, :W, :].contiguous()
        x1 = x1.view(B, H * W, C // self.branch_num)
        # ----------------------------------------------------------------------------------------------
        pad_l = pad_t = 0
        pad_r = (self.window_sizes[2] - W % self.window_sizes[2]) % self.window_sizes[2]
        pad_b = (self.window_sizes[2] - H % self.window_sizes[2]) % self.window_sizes[2]
        x2 = F.pad(x2, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x2.shape
        attn_mask = None
        x_windows = window_partition(x2, self.window_sizes[2])  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_sizes[2] * self.window_sizes[2],
                                   C // self.branch_num)  # [nW*B, Mh*Mw, C]

        attn_windows = self.attn2(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]

        attn_windows = attn_windows.view(-1, self.window_sizes[2], self.window_sizes[2],
                                         C // self.branch_num)  # [nW*B, Mh, Mw, C]
        x2 = window_reverse(attn_windows, self.window_sizes[2], Hp, Wp)  # [B, H', W', C]

        if pad_r > 0 or pad_b > 0:
            x2 = x2[:, :H, :W, :].contiguous()

        x2 = x2.view(B, H * W, C // self.branch_num)

        x = torch.cat([x0, x1, x2], -1)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class BasicLayer(nn.Module):
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2

        # build blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_sizes=(7, 4, 2),
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def create_mask(self, x, H, W):
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, H, W):
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2

        return x, H, W


class Transformer(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # number of channels of the stage-4 output feature map
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                depth=depths[i_layer],
                                num_heads=num_heads[i_layer],
                                window_size=window_size,
                                mlp_ratio=self.mlp_ratio,
                                qkv_bias=qkv_bias,
                                drop=drop_rate,
                                attn_drop=attn_drop_rate,
                                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                norm_layer=norm_layer,
                                downsample=PatchEmbed2 if (i_layer < self.num_layers - 1) else None,
                                use_checkpoint=use_checkpoint)
            self.layers.append(layers)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # x: [B, L, C]
        x, H, W = self.patch_embed(x)
        x = self.pos_drop(x)

        for layer in self.layers:
            x, H, W = layer(x, H, W)

        x = self.norm(x)  # [B, L, C]
        x = self.avgpool(x.transpose(1, 2))  # [B, C, 1]
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x


def MWT(num_classes: int = 1000, **kwargs):
    model = Transformer(in_chans=3,
                        patch_size=4,

(remainder of the 521-line file is not shown in this view)
+
# window_sizes=(7,4,2),
|
| 508 |
+
embed_dim=96,
|
| 509 |
+
depths=(2, 4, 4, 2),
|
| 510 |
+
num_heads=(3, 6, 12, 24),
|
| 511 |
+
num_classes=num_classes,
|
| 512 |
+
**kwargs)
|
| 513 |
+
return model
|
| 514 |
+
|
| 515 |
+
if __name__ == '__main__':
|
| 516 |
+
model = MWT(num_classes=2)
|
| 517 |
+
input = torch.randn(1, 3, 224, 224)
|
| 518 |
+
flops, params = profile(model, inputs=(input,))
|
| 519 |
+
print(flops)
|
| 520 |
+
print(params)
|
| 521 |
+
|
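
# Illustrative inference sketch (commented out; not part of the original file). It assumes
# that backend/MWTclass2.pth stores a state_dict compatible with MWT(num_classes=2) -- an
# assumption, since the training script is not shown in this commit:
#   model = MWT(num_classes=2)
#   state = torch.load("backend/MWTclass2.pth", map_location="cpu")
#   model.load_state_dict(state)
#   model.eval()
#   with torch.no_grad():
#       probs = torch.softmax(model(torch.randn(1, 3, 224, 224)), dim=1)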
backend/model_histo.py
ADDED
@@ -0,0 +1,1495 @@
"""
Breast Cancer Histopathology Classification using Path Foundation Model

This module implements a comprehensive deep learning pipeline for breast cancer classification
from histopathology images using Google's Path Foundation model as a feature extractor. The
system supports multiple datasets including BreakHis, PatchCamelyon (PCam), and BACH, employing
transfer learning to achieve high classification accuracy.

**Overview:**
This system leverages Google's Path Foundation model, which is pre-trained on a large corpus
of pathology images, to extract meaningful features from breast cancer histopathology images.
The approach uses transfer learning where the foundation model serves as a frozen feature
extractor, followed by a trainable classification head for binary classification (benign vs malignant).

**Model Architecture:**
- Foundation Model: Google's Path Foundation (pre-trained on pathology images)
- Transfer Learning Approach: Feature extraction with frozen foundation model + trainable classifier head
- Classification Head: Multi-layer dense network with regularisation and dropout
- Optimisation: AdamW optimiser with learning rate scheduling and early stopping

**Workflow:**
1. Authentication & Model Loading: Authenticate with Hugging Face and load Path Foundation
2. Data Loading: Load and preprocess histopathology datasets
3. Feature Extraction: Extract embeddings using frozen foundation model
4. Classifier Training: Train dense neural network on extracted features
5. Evaluation: Comprehensive performance analysis with multiple metrics and visualisations

**Supported Datasets:**
- BreakHis: Breast cancer histopathology images at multiple magnifications
- PatchCamelyon (PCam): Lymph node metastasis detection patches
- BACH: ICIAR 2018 Breast Cancer Histology Challenge dataset
- Combined: Ensemble of all three datasets for robust training

**Key Features:**
- Multiple dataset support with consistent pre-processing
- Robust error handling and fallback mechanisms
- Comprehensive evaluation metrics and visualisation
- Memory-efficient batch processing
- Data augmentation capabilities
- Model persistence and deployment support

Author: Research Team
Date: 2024
License: MIT
"""

# Import required libraries and configure environment
import os
import tensorflow as tf
import numpy as np
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from pathlib import Path
import h5py
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import regularizers
import matplotlib
# Use a non-interactive backend to prevent blocking on plt.show()
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress TensorFlow logging for cleaner output
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Configure TensorFlow logging for cleaner output
try:
    tf.get_logger().setLevel('ERROR')
except AttributeError:
    import logging
    logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Configure Hugging Face Hub integration with fallback mechanisms
# This section handles the loading of Google's Path Foundation model from Hugging Face Hub
# with multiple fallback methods to ensure compatibility across different environments
try:
    from huggingface_hub import login, hf_hub_download, snapshot_download

    # Try different methods for loading Keras models from HF Hub
    # Method 1: Direct Keras loading (preferred)
    try:
        from huggingface_hub import from_pretrained_keras
        KERAS_METHOD = "from_pretrained_keras"
    except ImportError:
        # Method 2: Transformers library fallback
        try:
            from transformers import TFAutoModel
            KERAS_METHOD = "transformers"
        except ImportError:
            # Method 3: Manual download and TFSMLayer
            KERAS_METHOD = "manual"

    HF_AVAILABLE = True
    print(f"Hugging Face Hub loaded successfully (method: {KERAS_METHOD})")
except ImportError as e:
    print(f"Hugging Face Hub unavailable: {e}")
    print("Please install required packages: pip install huggingface_hub transformers")
    HF_AVAILABLE = False
    KERAS_METHOD = None

class BreastCancerClassifier:
    """
    A comprehensive breast cancer classification system using Path Foundation model.

    This class implements a transfer learning approach where Google's Path Foundation
    model serves as a feature extractor, followed by a trainable classification head.
    The system supports both feature extraction (frozen foundation model) and
    fine-tuning approaches for maximum flexibility.

    The classifier can work with multiple histopathology datasets and provides
    comprehensive evaluation capabilities including confusion matrices, classification
    reports, and performance metrics.

    Attributes:
        fine_tune (bool): Whether to fine-tune the foundation model or use it frozen
        model (tf.keras.Model): The complete classification model
        path_foundation: The loaded Path Foundation model from Hugging Face Hub
        history: Training history from model.fit() containing loss and accuracy curves
        embedding_dim (int): Dimensionality of extracted embeddings from foundation model
        num_classes (int): Number of output classes (default: 2 for binary classification)

    Example:
        >>> classifier = BreastCancerClassifier(fine_tune=False)
        >>> classifier.authenticate_huggingface()
        >>> classifier.load_path_foundation()
        >>> # Load data and train...
    """

    def __init__(self, fine_tune=False):
        """
        Initialise the breast cancer classifier.

        Args:
            fine_tune (bool): If True, allows fine-tuning of foundation model.
                If False, uses foundation model as frozen feature extractor.

        Note: Fine-tuning requires more computational resources and
        may lead to overfitting on smaller datasets. Feature extraction
        (fine_tune=False) is recommended for most use-cases.
        """
        self.fine_tune = fine_tune
        self.model = None
        self.path_foundation = None
        self.history = None
        self.embedding_dim = None
        self.num_classes = 2  # Binary classification: benign vs malignant

    def authenticate_huggingface(self, token=None):
        """
        Authenticate with Hugging Face Hub to access Path Foundation model.

        This method handles authentication with Hugging Face Hub, which is required
        to download and use Google's Path Foundation model. It supports multiple
        token sources and provides fallback mechanisms.

        Args:
            token (str, optional): Hugging Face access token. If None, the method
                will attempt to use environment variables:
                - HF_TOKEN
                - HUGGINGFACE_HUB_TOKEN

        Returns:
            bool: True if authentication successful, False otherwise

        Note:
            You can obtain a Hugging Face token by:
            1. Creating an account at https://huggingface.co
            2. Going to Settings > Access Tokens
            3. Creating a new token with read permissions

        Example:
            >>> classifier = BreastCancerClassifier()
            >>> success = classifier.authenticate_huggingface("hf_xxxxxxxxxxxx")
            >>> if success:
            ...     print("Authentication successful")
        """
        if not HF_AVAILABLE:
            print("Cannot authenticate - Hugging Face Hub not available")
            return False

        try:
            # Try multiple token sources: parameter, environment variables
            final_token = token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")

            if final_token:
                login(token=final_token, add_to_git_credential=False)
                print("Hugging Face authentication successful")
                return True
            else:
                print("No token provided, attempting to use cached login")
                return True
        except Exception as e:
            print(f"Authentication failed: {e}")
            return False

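    # Usage sketch (commented out; not part of the original file): the token can also be
    # supplied through the environment before calling this method, e.g.
    #   os.environ["HF_TOKEN"] = "hf_..."   # placeholder value, not a real token
    #   classifier.authenticate_huggingface()
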
    def load_path_foundation(self):
        """
        Load Google's Path Foundation model with multiple fallback mechanisms.

        This method attempts to load the Path Foundation model using three different
        approaches to ensure maximum compatibility across different environments:

        1. Direct Keras loading via huggingface_hub (preferred)
        2. Transformers library loading (fallback)
        3. Manual download and TFSMLayer loading (last resort)

        The method also configures the model's training behavior based on the
        fine_tune parameter set during initialization.

        Returns:
            bool: True if model loaded successfully, False otherwise

        Raises:
            Various exceptions may be raised during the loading process, but they
            are caught and handled gracefully with informative error messages.

        Note:
            The Path Foundation model is a large pre-trained model (~1GB) that will
            be downloaded on first use. Subsequent runs will use the cached version.

        Example:
            >>> classifier = BreastCancerClassifier(fine_tune=False)
            >>> if classifier.load_path_foundation():
            ...     print("Model loaded successfully")
            ... else:
            ...     print("Failed to load model")
        """
        if not HF_AVAILABLE:
            print("Cannot load model - Hugging Face Hub unavailable")
            return False

        try:
            print("Loading Path Foundation model...")
            loaded = False

            # Method 1: Direct Keras loading (preferred method)
            if KERAS_METHOD == "from_pretrained_keras":
                try:
                    self.path_foundation = from_pretrained_keras("google/path-foundation")
                    loaded = True
                    print("Successfully loaded via from_pretrained_keras")
                except Exception as e:
                    print(f"Keras loading failed: {e}")

            # Method 2: Transformers library fallback
            if not loaded and KERAS_METHOD == "transformers":
                try:
                    print("Attempting transformers fallback...")
                    self.path_foundation = TFAutoModel.from_pretrained("google/path-foundation")
                    loaded = True
                    print("Successfully loaded via transformers")
                except Exception as e:
                    print(f"Transformers loading failed: {e}")

            # Method 3: Manual download and TFSMLayer (last resort)
            if not loaded:
                try:
                    try:
                        import keras as _standalone_keras
                    except ImportError as _e:
                        print(f"Keras 3 not installed: {_e}")
                        return False

                    print("Attempting manual download and TFSMLayer loading...")
                    local_dir = snapshot_download(repo_id="google/path-foundation")
                    self.path_foundation = _standalone_keras.layers.TFSMLayer(
                        local_dir, call_endpoint="serving_default"
                    )
                    loaded = True
                    print("Successfully loaded via TFSMLayer")
                except Exception as e:
                    print(f"TFSMLayer loading failed: {e}")
                    return False

            # Configure training behavior based on fine_tune setting
            if self.fine_tune:
                self.path_foundation.trainable = True
                try:
                    # Only fine-tune the last 3 layers for stability
                    for layer in self.path_foundation.layers[:-3]:
                        layer.trainable = False
                    print("Fine-tuning enabled: last 3 layers trainable")
                except Exception:
                    print("Fine-tuning enabled: full model trainable")
            else:
                self.path_foundation.trainable = False
                print("Model frozen for feature extraction")

            return True

        except Exception as e:
            print(f"Failed to load Path Foundation model: {e}")
            return False

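    # Smoke-test sketch (commented out; not part of the original file): once the foundation
    # model is loaded, a single random batch can confirm the embedding interface works, e.g.
    #   dummy = np.random.rand(2, 224, 224, 3).astype("float32")
    #   emb = classifier.extract_embeddings(dummy, batch_size=2)   # expected shape: (2, embedding_dim)
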
    def preprocess_image_batch(self, images):
        """
        Pre-process a batch of images for Path Foundation model input.

        This method handles multiple input formats and ensures all images are properly
        formatted for the Path Foundation model. It performs the following operations:
        - Resizes all images to 224x224 pixels (required by Path Foundation)
        - Converts images to RGB format
        - Normalises pixel values to [0, 1] range
        - Handles both file paths and numpy arrays

        Args:
            images: List or array of images in various formats:
                - File paths (strings) pointing to image files
                - PIL Images
                - NumPy arrays (various shapes and value ranges)

        Returns:
            np.ndarray: Preprocessed batch of shape (batch_size, 224, 224, 3)
                with pixel values normalized to [0, 1] range

        Note:
            The method automatically handles different input formats and value ranges.
            Images are resized using PIL's resize method with default interpolation.

        Example:
            >>> # Process file paths
            >>> image_paths = ['image1.jpg', 'image2.png']
            >>> processed = classifier.preprocess_image_batch(image_paths)
            >>> print(processed.shape)  # (2, 224, 224, 3)

            >>> # Process numpy arrays
            >>> image_arrays = [np.random.rand(100, 100, 3) for _ in range(5)]
            >>> processed = classifier.preprocess_image_batch(image_arrays)
            >>> print(processed.shape)  # (5, 224, 224, 3)
        """
        processed = []

        for img in images:
            if isinstance(img, str):
                # Handle file paths
                img = Image.open(img).convert('RGB')
                img = img.resize((224, 224))
                img_array = np.array(img) / 255.0
            else:
                # Handle numpy arrays
                if img.shape[:2] != (224, 224):
                    # Resize if necessary
                    if img.max() <= 1:
                        img_pil = Image.fromarray((img * 255).astype('uint8'))
                    else:
                        img_pil = Image.fromarray(img.astype('uint8'))
                    img_pil = img_pil.resize((224, 224))
                    img_array = np.array(img_pil) / 255.0
                else:
                    img_array = img.astype('float32')
                    if img_array.max() > 1:
                        img_array = img_array / 255.0

            processed.append(img_array)

        return np.array(processed)

    def extract_embeddings(self, images, batch_size=16):
        """
        Extract feature embeddings from images using Path Foundation model.

        This method processes images in batches to extract high-level feature representations
        using the pre-trained Path Foundation model. The embeddings capture semantic information
        about the histopathology images that can be used for classification.

        The method handles different model interface types and provides progress tracking
        for large datasets. It automatically determines the embedding dimension on first use.

        Args:
            images: Array of preprocessed images or list of image paths
            batch_size (int): Number of images to process per batch. Smaller batches
                use less memory but may be slower. Default: 16

        Returns:
            np.ndarray: Extracted embeddings of shape (num_images, embedding_dim)
                where embedding_dim is determined by the Path Foundation model

        Raises:
            ValueError: If no embeddings are successfully extracted
            RuntimeError: If the Path Foundation model is not loaded

        Note:
            The embedding dimension is automatically determined from the first successful
            batch and stored in self.embedding_dim for use in classifier construction.

        Example:
            >>> # Extract embeddings from a dataset
            >>> embeddings = classifier.extract_embeddings(images, batch_size=32)
            >>> print(f"Extracted {embeddings.shape[0]} embeddings of dimension {embeddings.shape[1]}")

            >>> # Process with smaller batch size for memory-constrained environments
            >>> embeddings = classifier.extract_embeddings(images, batch_size=8)
        """
        print(f"Extracting embeddings from {len(images)} images...")

        embeddings = []
        num_batches = (len(images) + batch_size - 1) // batch_size

        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size]
            processed_batch = self.preprocess_image_batch(batch)
            # Compute the batch index before the try block so it is defined in the error handler
            batch_num = i // batch_size + 1

            try:
                # Handle different model interface types
                if hasattr(self.path_foundation, 'signatures') and "serving_default" in self.path_foundation.signatures:
                    # TensorFlow SavedModel format
                    infer = self.path_foundation.signatures["serving_default"]
                    batch_embeddings = infer(tf.constant(processed_batch))
                elif hasattr(self.path_foundation, 'predict'):
                    # Standard Keras model
                    batch_embeddings = self.path_foundation.predict(processed_batch, verbose=0)
                else:
                    # Direct callable
                    batch_embeddings = self.path_foundation(processed_batch)

                # Handle different output formats
                if isinstance(batch_embeddings, dict):
                    key = list(batch_embeddings.keys())[0]
                    if hasattr(batch_embeddings[key], 'numpy'):
                        batch_embeddings = batch_embeddings[key].numpy()
                    else:
                        batch_embeddings = batch_embeddings[key]
                elif hasattr(batch_embeddings, 'numpy'):
                    batch_embeddings = batch_embeddings.numpy()

                embeddings.append(batch_embeddings)

                # Progress reporting
                if batch_num % 10 == 0:
                    print(f"Processed batch {batch_num}/{num_batches}")

            except Exception as e:
                print(f"Error processing batch {batch_num}: {e}")
                continue

        if not embeddings:
            raise ValueError("No embeddings extracted successfully")

        final_embeddings = np.vstack(embeddings)

        # Set embedding dimension for classifier head
        if self.embedding_dim is None:
            self.embedding_dim = final_embeddings.shape[1]
            print(f"Embedding dimension: {self.embedding_dim}")

        print(f"Final embeddings shape: {final_embeddings.shape}")
        return final_embeddings

    def build_classifier(self):
        """
        Build the classification head architecture.

        This method constructs the neural network architecture for breast cancer classification.
        It creates different architectures based on the fine_tune setting:

        1. End-to-end model (fine_tune=True): Input -> Path Foundation -> Classifier -> Output
        2. Feature-based model (fine_tune=False): Embeddings -> Classifier -> Output

        The architecture includes:
        - Progressive dimensionality reduction (768 -> 384 -> 192 -> 2)
        - L2 regularisation for weight decay and overfitting prevention
        - Batch normalisation for training stability and faster convergence
        - Dropout layers for regularization
        - AdamW optimizer with appropriate learning rates

        Returns:
            None: The model is stored in self.model and compiled

        Raises:
            ValueError: If embedding dimension is not set (run extract_embeddings first)

        Note:
            The method automatically selects appropriate learning rates:
            - Lower learning rate (1e-5) for fine-tuning to preserve pre-trained features
            - Higher learning rate (0.001) for training from scratch on embeddings

        Architecture Details:
            - Input: Either raw images (224x224x3) or embeddings (embedding_dim,)
            - Hidden layers: 768 -> 384 -> 192 neurons with ReLU activation
            - Output: 2 neurons with softmax activation (benign/malignant)
            - Regularisation: L2 weight decay (1e-4), Dropout (0.5, 0.3, 0.2)
            - Normalisation: Batch normalisation after each dense layer

        Example:
            >>> classifier = BreastCancerClassifier(fine_tune=False)
            >>> classifier.load_path_foundation()
            >>> embeddings = classifier.extract_embeddings(images)
            >>> classifier.build_classifier()
            >>> print(f"Model has {classifier.model.count_params():,} parameters")
        """
        if self.embedding_dim is None:
            raise ValueError("Embedding dimension not set - run extract_embeddings first")

        if self.fine_tune:
            # End-to-end fine-tuning architecture
            inputs = tf.keras.Input(shape=(224, 224, 3))
            x = self.path_foundation(inputs)

            # Classification head with regularization
            x = tf.keras.layers.Dense(768, activation='relu',
                                      kernel_regularizer=regularizers.l2(1e-4))(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Dropout(0.5)(x)

            x = tf.keras.layers.Dense(384, activation='relu',
                                      kernel_regularizer=regularizers.l2(1e-4))(x)
            x = tf.keras.layers.BatchNormalization()(x)
            x = tf.keras.layers.Dropout(0.3)(x)

            x = tf.keras.layers.Dense(192, activation='relu',
                                      kernel_regularizer=regularizers.l2(1e-4))(x)
            x = tf.keras.layers.Dropout(0.2)(x)

            outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(x)
            self.model = tf.keras.Model(inputs, outputs)

            # Lower learning rate for fine-tuning to preserve pre-trained features
            optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-5, weight_decay=1e-5)

        else:
            # Feature extraction architecture (recommended approach)
            self.model = tf.keras.Sequential([
                tf.keras.layers.Input(shape=(self.embedding_dim,)),

                # First dense block
                tf.keras.layers.Dense(768, activation='relu',
                                      kernel_regularizer=regularizers.l2(1e-4)),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Dropout(0.5),

                # Second dense block
                tf.keras.layers.Dense(384, activation='relu',
                                      kernel_regularizer=regularizers.l2(1e-4)),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Dropout(0.3),

                # Third dense block
                tf.keras.layers.Dense(192, activation='relu',
                                      kernel_regularizer=regularizers.l2(1e-4)),
                tf.keras.layers.Dropout(0.2),

                # Output layer
                tf.keras.layers.Dense(self.num_classes, activation='softmax')
            ])

            # Higher learning rate for training from scratch
            optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001, weight_decay=1e-5)

        # Compile model with sparse categorical crossentropy for integer labels
        self.model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy']
        )

        print(f"Model architecture built - Fine-tuning: {self.fine_tune}")
        print(f"Total parameters: {self.model.count_params():,}")

    def train_model(self, X_train, y_train, X_val, y_val, epochs=50):
        """
        Train the classification model with advanced techniques and callbacks.

        This method implements a comprehensive training pipeline with:
        - Class balancing to handle imbalanced datasets
        - Early stopping to prevent overfitting
        - Learning rate reduction on plateau
        - Model checkpointing to save best weights
        - Adaptive batch sizing based on training mode

        Args:
            X_train: Training features (embeddings or images)
            y_train: Training labels (0 for benign, 1 for malignant)
            X_val: Validation features
            y_val: Validation labels
            epochs (int): Maximum number of training epochs. Default: 50

        Returns:
            tf.keras.callbacks.History: Training history containing loss and accuracy curves

        Note:
            The method automatically handles class imbalance by computing balanced weights.
            Training uses different batch sizes: 32 for fine-tuning, 64 for feature extraction.

        Callbacks Used:
            - EarlyStopping: Stops training if validation accuracy doesn't improve for 10 epochs
            - ReduceLROnPlateau: Reduces learning rate by 50% if validation loss plateaus
            - ModelCheckpoint: Saves the best model based on validation accuracy

        Example:
            >>> # Train the model
            >>> history = classifier.train_model(X_train, y_train, X_val, y_val, epochs=30)
            >>>
            >>> # Access training metrics
            >>> print(f"Final training accuracy: {history.history['accuracy'][-1]:.4f}")
            >>> print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
        """
        # Compute class weights to handle imbalanced datasets
        try:
            classes = np.unique(y_train)
            weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
            class_weight = {int(c): float(w) for c, w in zip(classes, weights)}
            print(f"Class weights computed: {class_weight}")
        except Exception:
            class_weight = None
            print("Using uniform class weights")

        # Define training callbacks for robust training
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=10,
                restore_best_weights=True,
                verbose=1
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-7,
                verbose=1
            ),
            tf.keras.callbacks.ModelCheckpoint(
                'best_model.keras',
                monitor='val_accuracy',
                save_best_only=True,
                verbose=0
            )
        ]

        print("Starting model training...")
        print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")

        # Adaptive batch sizing based on training mode
        batch_size = 32 if self.fine_tune else 64
        print(f"Using batch size: {batch_size}")

        # Train the model
        self.history = self.model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=1,
            class_weight=class_weight
        )

        print("Training completed successfully!")
        return self.history

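    # Plotting sketch (commented out; not part of the original file): the History returned by
    # train_model() can be visualised with the matplotlib instance imported above, e.g.
    #   history = classifier.train_model(X_train, y_train, X_val, y_val)
    #   plt.plot(history.history["accuracy"], label="train")
    #   plt.plot(history.history["val_accuracy"], label="validation")
    #   plt.legend()
    #   plt.savefig("training_curves.png")
    #   plt.close()
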
    def evaluate_model(self, X_test, y_test):
        """
        Comprehensive model evaluation with multiple performance metrics and visualisations.

        This method provides a thorough evaluation of the trained model including:
        - Accuracy, Precision, Recall, and F1-score calculations
        - Detailed classification report with per-class metrics
        - Confusion matrix visualisation and analysis
        - Model predictions and probabilities for further analysis

        Args:
            X_test: Test features (embeddings or images)
            y_test: True test labels (0 for benign, 1 for malignant)

        Returns:
            dict: Dictionary containing comprehensive evaluation results:
                - 'accuracy': Overall accuracy score
                - 'precision': Weighted average precision
                - 'recall': Weighted average recall
                - 'f1': Weighted average F1-score
                - 'predictions': Predicted class labels
                - 'probabilities': Prediction probabilities for each class
                - 'confusion_matrix': 2x2 confusion matrix

        Note:
            The method generates and saves a confusion matrix plot as 'confusion_matrix.png';
            because a non-interactive matplotlib backend is configured, the figure is closed
            rather than displayed. The plot uses a blue color scheme for clarity.

        Metrics Explanation:
            - Accuracy: Overall correctness of predictions
            - Precision: True positives / (True positives + False positives)
            - Recall: True positives / (True positives + False negatives)
            - F1-score: Harmonic mean of precision and recall

        Example:
            >>> # Evaluate the trained model
            >>> results = classifier.evaluate_model(X_test, y_test)
            >>>
            >>> # Access specific metrics
            >>> print(f"Test Accuracy: {results['accuracy']:.4f}")
            >>> print(f"F1-Score: {results['f1']:.4f}")
            >>>
            >>> # Analyze predictions
            >>> predictions = results['predictions']
            >>> probabilities = results['probabilities']
        """
        print("Evaluating model performance...")

        # Generate predictions and probabilities
        y_pred_proba = self.model.predict(X_test)
        y_pred = np.argmax(y_pred_proba, axis=1)

        # Calculate comprehensive metrics
        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred, average='weighted')
        recall = recall_score(y_test, y_pred, average='weighted')
        f1 = f1_score(y_test, y_pred, average='weighted')

        # Display results
        print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-Score: {f1:.4f}")

        # Detailed classification report
        class_names = ['Benign', 'Malignant']
        print("\nDetailed Classification Report:")
        print(classification_report(y_test, y_pred, target_names=class_names))

        # Generate and display confusion matrix
        cm = confusion_matrix(y_test, y_pred)

        # Create confusion matrix visualization
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix - Breast Cancer Classification')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
        # Close the figure to free resources and avoid blocking
        plt.close()

        # Print confusion matrix in text format
        print("\nConfusion Matrix:")
        print(f"                 Predicted")
        print(f"           {class_names[0]:>8} {class_names[1]:>8}")
        print(f"Actual {class_names[0]:>6} {cm[0,0]:>8} {cm[0,1]:>8}")
        print(f"       {class_names[1]:>6} {cm[1,0]:>8} {cm[1,1]:>8}")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'predictions': y_pred,
            'probabilities': y_pred_proba,
            'confusion_matrix': cm
        }

def load_breakhis_data(data_dir="datasets/breakhis/histology_slides/breast", max_samples_per_class=2000, magnification="40X"):
    """
    Load and preprocess the BreakHis breast cancer histopathology dataset.

    The BreakHis dataset contains microscopic images of breast tumor tissue
    collected from clinical studies. Images are organized by:
    - Tumor type (benign/malignant)
    - Specific histological type (adenosis, fibroadenoma, etc.)
    - Patient ID
    - Magnification level (40X, 100X, 200X, 400X)

    This function loads images from the specified magnification level and
    preprocesses them for use with the Path Foundation model.

    Args:
        data_dir (str): Path to BreakHis dataset root directory. Default structure:
            datasets/breakhis/histology_slides/breast/
        max_samples_per_class (int): Maximum images to load per class (benign/malignant).
            Helps manage memory usage for large datasets.
        magnification (str): Magnification level to use. Options: "40X", "100X", "200X", "400X".
            Higher magnifications provide more detail but larger file sizes.

    Returns:
        tuple: (images, labels) as numpy arrays
        - images: Array of shape (num_images, 224, 224, 3) with normalized pixel values
        - labels: Array of shape (num_images,) with 0 for benign, 1 for malignant

    Dataset Structure:
        The function expects the following directory structure:
        data_dir/
        ├── benign/SOB/
        │   ├── adenosis/
        │   ├── fibroadenoma/
        │   ├── phyllodes_tumor/
        │   └── tubular_adenoma/
        └── malignant/SOB/
            ├── ductal_carcinoma/
            ├── lobular_carcinoma/
            ├── mucinous_carcinoma/
            └── papillary_carcinoma/

    Note:
        Images are automatically resized to 224x224 pixels and normalized to [0,1] range.
        The function handles various image formats (PNG, JPG, JPEG, TIF, TIFF).

    Example:
        >>> # Load BreakHis dataset with 40X magnification
        >>> images, labels = load_breakhis_data(
        ...     data_dir="datasets/breakhis/histology_slides/breast",
        ...     max_samples_per_class=1000,
        ...     magnification="40X"
        ... )
        >>> print(f"Loaded {len(images)} images")
        >>> print(f"Benign: {np.sum(labels == 0)}, Malignant: {np.sum(labels == 1)}")
    """
    print(f"Loading BreakHis dataset (magnification: {magnification})...")

    benign_dir = os.path.join(data_dir, "benign", "SOB")
    malignant_dir = os.path.join(data_dir, "malignant", "SOB")

    images = []
    labels = []

    def load_images_from_category(base_dir, label, max_count):
        """
        Helper function to load images from a specific category (benign/malignant).

        Traverses the directory structure: base_dir/tumor_type/patient_id/magnification/images
        and loads images with progress reporting.
        """
        if not os.path.exists(base_dir):
            print(f"Warning: Directory {base_dir} not found")
            return 0

        count = 0

        # Traverse: base_dir/tumor_type/patient_id/magnification/images
        for tumor_type in os.listdir(base_dir):
            tumor_dir = os.path.join(base_dir, tumor_type)
            if not os.path.isdir(tumor_dir):
                continue

            for patient_id in os.listdir(tumor_dir):
                patient_dir = os.path.join(tumor_dir, patient_id)
                if not os.path.isdir(patient_dir):
                    continue

                mag_dir = os.path.join(patient_dir, magnification)
                if not os.path.exists(mag_dir):
                    continue

                for filename in os.listdir(mag_dir):
                    if count >= max_count:
                        return count

                    if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')):
                        try:
                            img_path = os.path.join(mag_dir, filename)
                            img = Image.open(img_path).convert('RGB')
                            img = img.resize((224, 224))
                            img_array = np.array(img).astype('float32') / 255.0
                            images.append(img_array)
                            labels.append(label)
                            count += 1

                            if count % 100 == 0:
                                category = 'benign' if label == 0 else 'malignant'
                                print(f"Loaded {count} {category} images")

                        except Exception as e:
                            print(f"Error loading {filename}: {e}")
                            continue
        return count

    # Load both categories
    benign_count = load_images_from_category(benign_dir, 0, max_samples_per_class)
    malignant_count = load_images_from_category(malignant_dir, 1, max_samples_per_class)

    print(f"BreakHis dataset loaded: {benign_count} benign, {malignant_count} malignant images")

    return np.array(images), np.array(labels)

def load_pcam_data(data_dir="datasets/pcam", label_dir="datasets/Labels", max_samples=3000, augment=True):
|
| 877 |
+
"""
|
| 878 |
+
Load and preprocess the PatchCamelyon (PCam) dataset.
|
| 879 |
+
|
| 880 |
+
PCam contains 96x96 pixel patches extracted from histopathologic scans
|
| 881 |
+
of lymph node sections. Each patch is labeled with the presence of
|
| 882 |
+
metastatic tissue. This function includes data augmentation capabilities
|
| 883 |
+
    to improve model generalization.

    The dataset is stored in HDF5 format with separate files for images and labels,
    and comes pre-split into training, validation, and test sets.

    Args:
        data_dir (str): Path to PCam image data directory containing:
            - training_split.h5
            - validation_split.h5
            - test_split.h5
        label_dir (str): Path to PCam label files directory containing:
            - camelyonpatch_level_2_split_train_y.h5
            - camelyonpatch_level_2_split_valid_y.h5
            - camelyonpatch_level_2_split_test_y.h5
        max_samples (int): Maximum total samples to load across all splits.
            Distributed as: train=50%, val=25%, test=25%
        augment (bool): Whether to apply data augmentation to training set.
            Augmentation includes: horizontal flip, rotation, brightness adjustment

    Returns:
        dict: Dictionary with 'train', 'valid', 'test' keys containing (images, labels) tuples
            - 'train': (train_images, train_labels) - Training data with optional augmentation
            - 'valid': (val_images, val_labels) - Validation data
            - 'test': (test_images, test_labels) - Test data

    Dataset Details:
        - Original patch size: 96x96 pixels
        - Resized to: 224x224 pixels for Path Foundation compatibility
        - Labels: 0 (normal tissue), 1 (metastatic tissue)
        - Format: HDF5 files with 'x' key for images, 'y' key for labels

    Data Augmentation (if enabled):
        - Horizontal flip: 50% probability
        - Rotation: Random 0°, 90°, 180°, or 270° rotation
        - Brightness adjustment: 30% probability, factor between 0.9-1.1

    Note:
        The function automatically handles HDF5 file loading and memory management.
        Images are resized from 96x96 to 224x224 pixels and normalized to [0,1] range.

    Example:
        >>> # Load PCam dataset with augmentation
        >>> pcam_data = load_pcam_data(
        ...     data_dir="datasets/pcam",
        ...     label_dir="datasets/Labels",
        ...     max_samples=2000,
        ...     augment=True
        ... )
        >>>
        >>> # Access training data
        >>> train_images, train_labels = pcam_data['train']
        >>> print(f"Training samples: {len(train_images)}")
        >>> print(f"Image shape: {train_images[0].shape}")
    """
    print("Loading PatchCamelyon (PCam) dataset...")

    # Define file paths
    train_file = os.path.join(data_dir, "training_split.h5")
    val_file = os.path.join(data_dir, "validation_split.h5")
    test_file = os.path.join(data_dir, "test_split.h5")
    train_label_file = os.path.join(label_dir, "camelyonpatch_level_2_split_train_y.h5")
    val_label_file = os.path.join(label_dir, "camelyonpatch_level_2_split_valid_y.h5")
    test_label_file = os.path.join(label_dir, "camelyonpatch_level_2_split_test_y.h5")

    def preprocess(images):
        """Resize and normalize images from 96x96 to 224x224 pixels."""
        processed = []
        for img in images:
            im = Image.fromarray(img)
            im = im.resize((224, 224))  # Resize to match Path Foundation input
            arr = np.array(im).astype('float32') / 255.0
            processed.append(arr)
        return np.array(processed)

    def safe_load(img_file, label_file, limit):
        """Safely load data from HDF5 files with memory management."""
        with h5py.File(img_file, 'r') as f_img, h5py.File(label_file, 'r') as f_lbl:
            x = f_img['x'][:limit]
            y = f_lbl['y'][:limit]
            y = y.reshape(-1)  # Ensure 1D label array
            return x, y

    # Load data splits with sample limits
    train_images, train_labels = safe_load(train_file, train_label_file, max_samples//2)
    val_images, val_labels = safe_load(val_file, val_label_file, max_samples//4)
    test_images, test_labels = safe_load(test_file, test_label_file, max_samples//4)

    # Preprocess all splits
    train_images = preprocess(train_images)
    val_images = preprocess(val_images)
    test_images = preprocess(test_images)

    # Apply data augmentation to training set
    if augment:
        print("Applying data augmentation to training set...")
        for i in range(len(train_images)):
            # Random horizontal flip
            if np.random.rand() > 0.5:
                train_images[i] = np.fliplr(train_images[i])

            # Random rotation (0, 90, 180, 270 degrees)
            k = np.random.randint(0, 4)
            if k:
                train_images[i] = np.rot90(train_images[i], k)

            # Random brightness adjustment
            if np.random.rand() > 0.7:
                im = Image.fromarray((train_images[i] * 255).astype('uint8'))
                brightness_factor = 0.9 + 0.2 * np.random.rand()
                im = Image.fromarray(
                    np.clip(np.array(im, dtype=np.float32) * brightness_factor, 0, 255).astype('uint8')
                )
                train_images[i] = np.array(im).astype('float32') / 255.0

    print(f"PCam dataset loaded - Train: {len(train_images)}, Val: {len(val_images)}, Test: {len(test_images)}")

    return {
        'train': (train_images, train_labels),
        'valid': (val_images, val_labels),
        'test': (test_images, test_labels)
    }

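# --- Illustrative sketch (not part of the original module) ------------------
# A minimal, self-contained demo of the augmentation policy documented in
# load_pcam_data above (50% horizontal flip, random 90-degree rotation, and
# occasional brightness jitter with a factor in [0.9, 1.1]). The helper name
# `_augmentation_demo` is hypothetical and exists only to show the transforms
# in isolation on a dummy patch.
def _augmentation_demo(seed=0):
    import numpy as np  # repeated here only to keep the sketch self-contained

    rng = np.random.default_rng(seed)
    img = rng.random((224, 224, 3)).astype("float32")  # stand-in for one preprocessed patch

    if rng.random() > 0.5:          # horizontal flip, 50% probability
        img = np.fliplr(img)
    k = rng.integers(0, 4)          # rotation by 0/90/180/270 degrees
    if k:
        img = np.rot90(img, k)
    if rng.random() > 0.7:          # brightness jitter, ~30% probability
        factor = 0.9 + 0.2 * rng.random()
        img = np.clip(img * factor, 0.0, 1.0)
    return img
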
def load_bach_data(data_dir="datasets/BACH/ICIAR2018_BACH_Challenge/Photos", max_samples=400, augment=True):
    """
    Load and preprocess the BACH (ICIAR 2018) breast cancer histology dataset.

    BACH contains microscopy images classified into four categories:
    - Normal tissue
    - Benign lesions
    - In situ carcinoma
    - Invasive carcinoma

    For binary classification, this function maps:
    - Normal + Benign → Benign (label 0)
    - In situ + Invasive → Malignant (label 1)

    Args:
        data_dir (str): Path to BACH dataset directory containing class subdirectories:
            - Normal/
            - Benign/
            - InSitu/
            - Invasive/
        max_samples (int): Maximum total samples to load across all classes.
            Distributed evenly across the 4 classes.
        augment (bool): Whether to apply data augmentation (currently not implemented
            for BACH dataset but parameter kept for consistency)

    Returns:
        dict: Dictionary with 'train', 'valid', 'test' keys containing (images, labels) tuples
            - 'train': (train_images, train_labels) - Training data
            - 'valid': (val_images, val_labels) - Validation data
            - 'test': (test_images, test_labels) - Test data

    Dataset Details:
        - Original categories: 4 classes (Normal, Benign, InSitu, Invasive)
        - Binary mapping: Normal(0), Benign(1) → Benign(0); InSitu(2), Invasive(3) → Malignant(1)
        - Image format: TIF, TIFF, PNG, JPG, JPEG
        - Resized to: 224x224 pixels for Path Foundation compatibility
        - Normalized to: [0, 1] range

    Data Splitting:
        - Test set: 20% of total data
        - Training set: 60% of total data (75% of remaining after test split)
        - Validation set: 20% of total data (25% of remaining after test split)
        - Stratified splitting to maintain class distribution

    Note:
        The function automatically handles the 4-class to binary classification mapping.
        Images are resized to 224x224 pixels and normalized to [0,1] range.
        The augment parameter is kept for API consistency but augmentation is not
        currently implemented for the BACH dataset.

    Example:
        >>> # Load BACH dataset
        >>> bach_data = load_bach_data(
        ...     data_dir="datasets/BACH/ICIAR2018_BACH_Challenge/Photos",
        ...     max_samples=400,
        ...     augment=True
        ... )
        >>>
        >>> # Access training data
        >>> train_images, train_labels = bach_data['train']
        >>> print(f"Training samples: {len(train_images)}")
        >>> print(f"Class distribution: Benign={np.sum(train_labels==0)}, Malignant={np.sum(train_labels==1)}")
    """
    print("Loading BACH (ICIAR 2018) dataset...")

    # Original BACH categories mapped to binary classification
    class_dirs = {
        'Normal': 0,    # Normal tissue → Benign
        'Benign': 1,    # Benign lesions → Benign
        'InSitu': 2,    # In situ carcinoma → Malignant
        'Invasive': 3,  # Invasive carcinoma → Malignant
    }

    images = []
    labels = []
    per_class_limit = None if not max_samples else max_samples // 4
    counters = {0: 0, 1: 0, 2: 0, 3: 0}

    # Load images from each category
    for cls_name, cls_label in class_dirs.items():
        cls_path = os.path.join(data_dir, cls_name)
        if not os.path.isdir(cls_path):
            print(f"Warning: Directory {cls_path} not found")
            continue

        for fname in os.listdir(cls_path):
            if per_class_limit and counters[cls_label] >= per_class_limit:
                break
            if not fname.lower().endswith((".tif", ".tiff", ".png", ".jpg", ".jpeg")):
                continue

            fpath = os.path.join(cls_path, fname)
            try:
                im = Image.open(fpath).convert('RGB')
                im = im.resize((224, 224))
                arr = np.array(im).astype('float32') / 255.0
                images.append(arr)
                labels.append(cls_label)
                counters[cls_label] += 1
            except Exception as e:
                print(f"Error loading {fname}: {e}")
                continue

    images = np.array(images)
    labels = np.array(labels)

    # Convert 4-class to binary classification
    if labels.size > 0:
        # Map: Normal(0), Benign(1) → Benign(0); InSitu(2), Invasive(3) → Malignant(1)
        labels = np.where(np.isin(labels, [0, 1]), 0, 1)

    print(f"BACH dataset loaded: {len(images)} images")
    print(f"Class distribution - Benign: {np.sum(labels == 0)}, Malignant: {np.sum(labels == 1)}")

    # Split into train/validation/test sets
    X_temp, X_test, y_temp, y_test = train_test_split(
        images, labels, test_size=0.2,
        stratify=labels if len(set(labels)) > 1 else None,
        random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.25,
        stratify=y_temp if len(set(y_temp)) > 1 else None,
        random_state=42
    )

    return {
        'train': (X_train, y_train),
        'valid': (X_val, y_val),
        'test': (X_test, y_test)
    }

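# --- Illustrative sketch (not part of the original module) ------------------
# The 4-class to binary mapping used in load_bach_data, shown on a toy label
# array: Normal(0) and Benign(1) collapse to 0, InSitu(2) and Invasive(3)
# collapse to 1. `_bach_mapping_demo` is a hypothetical helper added only for
# illustration.
def _bach_mapping_demo():
    import numpy as np

    labels = np.array([0, 1, 2, 3, 3, 1])
    binary = np.where(np.isin(labels, [0, 1]), 0, 1)
    # binary -> array([0, 0, 1, 1, 1, 0])
    return binary
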
def load_combined_data(dataset_choice="breakhis", max_samples=5000):
    """
    Unified data loading function supporting multiple datasets and combinations.

    This function serves as the main entry point for data loading, supporting:
    - Individual datasets (BreakHis, PCam, BACH)
    - Combined dataset training for improved generalization
    - Consistent data splitting and preprocessing across all datasets

    The combined dataset approach leverages multiple histopathology datasets to
    create a more robust and generalizable model by training on diverse data sources.

    Args:
        dataset_choice (str): Dataset to load. Options:
            - "breakhis": BreakHis breast cancer histopathology dataset
            - "pcam": PatchCamelyon lymph node metastasis dataset
            - "bach": BACH ICIAR 2018 breast cancer histology dataset
            - "combined": Ensemble of all three datasets for robust training
        max_samples (int): Maximum total samples to load. For individual datasets,
            this is the total limit. For combined datasets, this is
            distributed across the constituent datasets.

    Returns:
        dict: Dictionary with 'train', 'valid', 'test' keys containing (images, labels) tuples
            - 'train': (train_images, train_labels) - Training data
            - 'valid': (val_images, val_labels) - Validation data
            - 'test': (test_images, test_labels) - Test data

    Dataset Combinations:
        When dataset_choice="combined", the function:
        1. Loads BreakHis, PCam, and BACH datasets
        2. Combines their training data
        3. Shuffles the combined dataset
        4. Splits into train/validation/test sets
        5. Maintains class balance through stratified splitting

    Sample Distribution (for combined datasets):
        - BreakHis: max_samples // 6 (per-class limit)
        - PCam: max_samples // 3 (total limit)
        - BACH: max_samples // 3 (total limit)

    Data Splitting:
        - Test set: 20% of total data
        - Training set: 60% of total data (75% of remaining after test split)
        - Validation set: 20% of total data (25% of remaining after test split)
        - Stratified splitting to maintain class distribution

    Note:
        All datasets are automatically preprocessed to 224x224 pixels and normalized
        to [0,1] range for compatibility with the Path Foundation model.

    Example:
        >>> # Load individual dataset
        >>> data = load_combined_data("breakhis", max_samples=2000)
        >>>
        >>> # Load combined dataset for robust training
        >>> combined_data = load_combined_data("combined", max_samples=6000)
        >>>
        >>> # Access training data
        >>> train_images, train_labels = combined_data['train']
        >>> print(f"Combined training samples: {len(train_images)}")
    """

    if dataset_choice.lower() == "breakhis":
        print("Loading BreakHis dataset only...")
        images, labels = load_breakhis_data(max_samples_per_class=max_samples//2)

        # Split into train/validation/test
        X_temp, X_test, y_temp, y_test = train_test_split(
            images, labels, test_size=0.2, stratify=labels, random_state=42
        )

        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=0.25, stratify=y_temp, random_state=42
        )

        return {
            'train': (X_train, y_train),
            'valid': (X_val, y_val),
            'test': (X_test, y_test)
        }

    elif dataset_choice.lower() == "pcam":
        return load_pcam_data(max_samples=max_samples)

    elif dataset_choice.lower() == "bach":
        return load_bach_data(max_samples=max_samples)

    elif dataset_choice.lower() == "combined":
        print("Loading combined datasets for enhanced generalization...")

        # Distribute samples across datasets
        if max_samples is None:
            per_bh = None
            per_pc = None
            per_ba = None
        else:
            per_dataset = max(1, max_samples // 3)
            per_bh = per_dataset // 2  # BreakHis uses per-class limit
            per_pc = per_dataset
            per_ba = per_dataset

        # Load individual datasets
        print("Loading BreakHis component...")
        bh_images, bh_labels = load_breakhis_data(
            max_samples_per_class=per_bh if per_bh else 10**9
        )

        print("Loading PCam component...")
        pcam = load_pcam_data(max_samples=per_pc, augment=True)
        pc_train_images, pc_train_labels = pcam["train"]

        print("Loading BACH component...")
        bach = load_bach_data(max_samples=per_ba, augment=True)
        b_train_images, b_train_labels = bach["train"]

        # Combine all datasets
        images = np.concatenate([bh_images, pc_train_images, b_train_images], axis=0)
        labels = np.concatenate([bh_labels, pc_train_labels, b_train_labels], axis=0)

        print(f"Combined dataset: {len(images)} total images")
        print(f"Final distribution - Benign: {np.sum(labels == 0)}, Malignant: {np.sum(labels == 1)}")

        # Shuffle combined data
        idx = np.arange(len(images))
        np.random.shuffle(idx)
        images, labels = images[idx], labels[idx]

        # Split combined data
        X_temp, X_test, y_temp, y_test = train_test_split(
            images, labels, test_size=0.2,
            stratify=labels if len(set(labels)) > 1 else None,
            random_state=42
        )
        X_train, X_val, y_train, y_val = train_test_split(
            X_temp, y_temp, test_size=0.25,
            stratify=y_temp if len(set(y_temp)) > 1 else None,
            random_state=42
        )

        return {
            'train': (X_train, y_train),
            'valid': (X_val, y_val),
            'test': (X_test, y_test)
        }

    else:
        raise ValueError(f"Unknown dataset choice: {dataset_choice}. "
                         f"Choose from: 'breakhis', 'pcam', 'bach', 'combined'")

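# --- Illustrative sketch (not part of the original module) ------------------
# How the "combined" sample budget is split, following the arithmetic in
# load_combined_data: each dataset gets max_samples // 3, and BreakHis uses
# half of its share as a per-class limit. `_combined_budget_demo` is a
# hypothetical helper for illustration only.
def _combined_budget_demo(max_samples=6000):
    per_dataset = max(1, max_samples // 3)   # 2000 when max_samples=6000
    per_breakhis_class = per_dataset // 2    # 1000 per BreakHis class (i.e. max_samples // 6)
    return {"pcam": per_dataset, "bach": per_dataset, "breakhis_per_class": per_breakhis_class}
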
def main():
    """
    Execute the complete breast cancer classification pipeline.

    This function coordinates all components of the machine learning workflow:
    1. Environment validation and setup
    2. Model authentication and loading
    3. Dataset loading and preprocessing
    4. Feature extraction using Path Foundation
    5. Classifier training with advanced techniques
    6. Comprehensive model evaluation
    7. Model persistence for future use

    The pipeline implements a robust transfer learning approach using Google's
    Path Foundation model as a feature extractor, followed by a trainable
    classification head for binary breast cancer classification.

    Returns:
        tuple: (classifier_instance, evaluation_results) or (None, None) if failed
            - classifier_instance: Trained BreastCancerClassifier object
            - evaluation_results: Dictionary containing performance metrics and predictions

    Configuration:
        The function uses global variables for configuration (can be modified):
        - DATASET_CHOICE: Dataset to use ("breakhis", "pcam", "bach", "combined")
        - MAX_SAMPLES: Maximum samples to load (adjust based on available memory)
        - EPOCHS: Number of training epochs (default: 50)
        - HF_TOKEN: Hugging Face authentication token (optional)

    Pipeline Steps:
        1. Prerequisites Check: Validates required packages and dependencies
        2. Authentication: Authenticates with Hugging Face Hub
        3. Model Loading: Downloads and loads Path Foundation model
        4. Data Loading: Loads and preprocesses histopathology dataset
        5. Feature Extraction: Extracts embeddings using frozen foundation model
        6. Classifier Building: Constructs trainable classification head
        7. Training: Trains classifier with callbacks and monitoring
        8. Evaluation: Comprehensive performance assessment
        9. Model Saving: Persists trained model for future use

    Error Handling:
        The function includes comprehensive error handling with detailed error messages
        and stack traces to aid in debugging and troubleshooting.

    Example:
        >>> # Run the complete pipeline
        >>> classifier, results = main()
        >>>
        >>> if results:
        ...     print(f"Pipeline successful! Accuracy: {results['accuracy']:.4f}")
        ...     # Use the trained classifier for inference
        ... else:
        ...     print("Pipeline failed - check error messages")

    Note:
        This function is designed to be run as a standalone script or imported
        and called from other modules. It provides a complete end-to-end
        machine learning pipeline for breast cancer classification.
    """
    print("="*60)
    print("BREAST CANCER CLASSIFICATION WITH PATH FOUNDATION")
    print("="*60)

    # Validate prerequisites
    if not HF_AVAILABLE:
        print("ERROR: Prerequisites not met")
        print("Required installations: pip install tensorflow huggingface_hub transformers")
        return None, None

    # Configuration parameters
    EPOCHS = 50
    HF_TOKEN = None  # Set your Hugging Face token here if needed

    # Global configuration (can be modified in notebook)
    if 'DATASET_CHOICE' not in globals():
        DATASET_CHOICE = 'combined'  # Options: 'breakhis', 'pcam', 'bach', 'combined'
    if 'MAX_SAMPLES' not in globals():
        MAX_SAMPLES = 4000

    print(f"Configuration:")
    print(f" - Epochs: {EPOCHS}")
    print(f" - Dataset: {DATASET_CHOICE}")
    print(f" - Max samples: {MAX_SAMPLES}")
    print(f" - Method: Feature extraction (frozen foundation model)")

    try:
        # Initialize classifier in feature extraction mode
        classifier = BreastCancerClassifier(fine_tune=False)

        print("\n" + "="*40)
        print("STEP 1: HUGGING FACE AUTHENTICATION")
        print("="*40)
        if not classifier.authenticate_huggingface(HF_TOKEN):
            raise Exception("Authentication failed - check your HF token")

        print("\n" + "="*40)
        print("STEP 2: LOADING PATH FOUNDATION MODEL")
        print("="*40)
        if not classifier.load_path_foundation():
            raise Exception("Model loading failed - check network connection")

        print("\n" + "="*40)
        print(f"STEP 3: LOADING {DATASET_CHOICE.upper()} DATASET")
        print("="*40)
        data = load_combined_data(DATASET_CHOICE, MAX_SAMPLES)

        X_train, y_train = data['train']
        X_val, y_val = data['valid']
        X_test, y_test = data['test']

        print(f"Dataset splits:")
        print(f" - Training: {len(X_train)} samples")
        print(f" - Validation: {len(X_val)} samples")
        print(f" - Test: {len(X_test)} samples")

        print("\n" + "="*40)
        print("STEP 4: EXTRACTING FEATURE EMBEDDINGS")
        print("="*40)
        print("Extracting training embeddings...")
        X_train = classifier.extract_embeddings(X_train)
        print("Extracting validation embeddings...")
        X_val = classifier.extract_embeddings(X_val)
        print("Extracting test embeddings...")
        X_test = classifier.extract_embeddings(X_test)

        print("\n" + "="*40)
        print("STEP 5: BUILDING CLASSIFICATION HEAD")
        print("="*40)
        classifier.num_classes = 2
        classifier.build_classifier()

        print("\n" + "="*40)
        print("STEP 6: TRAINING CLASSIFIER")
        print("="*40)
        classifier.train_model(X_train, y_train, X_val, y_val, EPOCHS)

        print("\n" + "="*40)
        print("STEP 7: MODEL EVALUATION")
        print("="*40)
        results = classifier.evaluate_model(X_test, y_test)

        # Save trained model
        model_name = f"{DATASET_CHOICE}_breast_cancer_classifier.keras"
        classifier.model.save(model_name)
        print(f"\nModel saved as: {model_name}")

        print("\n" + "="*60)
        print("PIPELINE COMPLETED SUCCESSFULLY")
        print("="*60)
        print(f"Final Performance Metrics:")
        print(f" - Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
        print(f" - F1-Score: {results['f1']:.4f}")
        print(f" - Precision: {results['precision']:.4f}")
        print(f" - Recall: {results['recall']:.4f}")

        return classifier, results

    except Exception as e:
        print(f"\nERROR: Pipeline failed - {e}")
        import traceback
        traceback.print_exc()
        return None, None

# Script execution section
if __name__ == "__main__":
    """
    Main execution block for running the breast cancer classification pipeline.

    This section is executed when the script is run directly (not imported).
    It provides a simple interface to run the complete machine learning pipeline
    and displays the final results.

    Usage:
        python model2.py

    The script will:
    1. Initialize and run the complete pipeline
    2. Display progress and intermediate results
    3. Show final performance metrics
    4. Save the trained model for future use
    """
    print("Starting Breast Cancer Classification Pipeline...")
    print("This may take several minutes depending on your hardware and dataset size.")
    print("="*60)

    # Execute the complete pipeline
    classifier, results = main()

    # Display final results
    if results:
        print("\n" + "="*60)
        print("🎉 PIPELINE EXECUTION SUCCESSFUL! 🎉")
        print("="*60)
        print(f"Final Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
        print(f"F1-Score: {results['f1']:.4f}")
        print(f"Precision: {results['precision']:.4f}")
        print(f"Recall: {results['recall']:.4f}")
        print("\nThe trained model has been saved and is ready for inference!")
        print("You can now use the classifier for breast cancer classification tasks.")
    else:
        print("\n" + "="*60)
        print("❌ PIPELINE EXECUTION FAILED ❌")
        print("="*60)
        print("Please check the error messages above for troubleshooting.")
        print("Common issues:")
        print("- Missing dependencies (install with: pip install tensorflow huggingface_hub transformers)")
        print("- Network connectivity issues (for downloading Path Foundation model)")
        print("- Insufficient memory (reduce MAX_SAMPLES parameter)")
        print("- Invalid dataset paths (check dataset directory structure)")
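# --- Illustrative sketch (not part of the original module) ------------------
# A minimal example of reusing the classifier saved by main(). It assumes
# (a) the pipeline above has been run so that
# "combined_breast_cancer_classifier.keras" exists, and (b) `embeddings` is a
# NumPy array of Path Foundation embeddings with the same dimensionality used
# at training time. Both assumptions are illustrative, not part of the source.
def _inference_demo(embeddings):
    import numpy as np
    import tensorflow as tf

    model = tf.keras.models.load_model("combined_breast_cancer_classifier.keras")
    probs = model.predict(embeddings)   # shape: (n_samples, 2)
    preds = np.argmax(probs, axis=1)    # 0 = benign, 1 = malignant
    return preds, probs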
backend/requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
fastapi
uvicorn
pillow
ultralytics
tensorflow
numpy
huggingface_hub
joblib
scikit-learn
scikit-image
loguru
thop
seaborn
python-multipart
backend/yolo_colposcopy.pt
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f31fd29763c04b071b449db7bfa09743527f30a62913f25d5a6097366c4bf3b4
size 6235498
frontend/.eslintrc.cjs
ADDED
|
@@ -0,0 +1,18 @@
module.exports = {
  root: true,
  env: { browser: true, es2020: true },
  extends: [
    'eslint:recommended',
    'plugin:@typescript-eslint/recommended',
    'plugin:react-hooks/recommended',
  ],
  ignorePatterns: ['dist', '.eslintrc.cjs'],
  parser: '@typescript-eslint/parser',
  plugins: ['react-refresh'],
  rules: {
    'react-refresh/only-export-components': [
      'warn',
      { allowConstantExport: true },
    ],
  },
}
frontend/.gitignore
ADDED
|
@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*

node_modules
dist
dist-ssr
*.local

# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
frontend/README.md
ADDED
|
@@ -0,0 +1,8 @@
# Magic Patterns - Vite Template

This code was generated by [Magic Patterns](https://magicpatterns.com) for this design: [Source Design](https://www.magicpatterns.com/c/jitk86q9nv1at6tcwj7cxr)

## Getting Started

1. Run `npm install`
2. Run `npm run dev`
frontend/index.html
ADDED
|
@@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <link rel="icon" type="image/svg+xml" href="/vite.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Manalife AI Pathology Assistant</title>
  </head>
  <body>
    <div id="root"></div>
    <script type="module" src="/src/index.tsx"></script>
  </body>
</html>
frontend/package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
frontend/package.json
ADDED
|
@@ -0,0 +1,35 @@
{
  "name": "magic-patterns-vite-template",
  "version": "0.0.1",
  "private": true,
  "type": "module",
  "scripts": {
    "dev": "npx vite",
    "build": "npx vite build",
    "lint": "eslint . --ext .js,.jsx,.ts,.tsx",
    "preview": "npx vite preview"
  },
  "dependencies": {
    "axios": "^1.12.2",
    "lucide-react": "0.522.0",
    "react": "^18.3.1",
    "react-dom": "^18.3.1",
    "react-router-dom": "^6.26.2"
  },
  "devDependencies": {
    "@types/node": "^20.11.18",
    "@types/react": "^18.3.1",
    "@types/react-dom": "^18.3.1",
    "@typescript-eslint/eslint-plugin": "^5.54.0",
    "@typescript-eslint/parser": "^5.54.0",
    "@vitejs/plugin-react": "^4.2.1",
    "autoprefixer": "latest",
    "eslint": "^8.50.0",
    "eslint-plugin-react-hooks": "^4.6.0",
    "eslint-plugin-react-refresh": "^0.4.1",
    "postcss": "latest",
    "tailwindcss": "3.4.17",
    "typescript": "^5.5.4",
    "vite": "^5.2.0"
  }
}
frontend/postcss.config.js
ADDED
|
@@ -0,0 +1,6 @@
export default {
  plugins: {
    tailwindcss: {},
    autoprefixer: {},
  },
}
frontend/public/banner.jpeg
ADDED
|
frontend/public/black_logo.png
ADDED
|
frontend/public/colpo/colp1.jpg
ADDED
|
frontend/public/colpo/colp2.jpg
ADDED
|
frontend/public/colpo/colp3.jpg
ADDED
|
frontend/public/cyto/cyt1.jpg
ADDED
|
Git LFS Details
|
frontend/public/cyto/cyt2.png
ADDED
|
Git LFS Details
|
frontend/public/cyto/cyt3.png
ADDED
|
Git LFS Details
|
frontend/public/histo/hist1.png
ADDED
|
frontend/public/histo/hist2.png
ADDED
|
frontend/public/histo/hist3.jpg
ADDED
|
frontend/public/manalife_LOGO.jpg
ADDED
|
frontend/public/white_logo.png
ADDED
|
frontend/src/App.tsx
ADDED
|
@@ -0,0 +1,125 @@
| 1 |
+
import { useState, useEffect } from "react";
|
| 2 |
+
import axios from "axios";
|
| 3 |
+
import { Header } from "./components/Header";
|
| 4 |
+
import { Sidebar } from "./components/Sidebar";
|
| 5 |
+
import { UploadSection } from "./components/UploadSection";
|
| 6 |
+
import { ResultsPanel } from "./components/ResultsPanel";
|
| 7 |
+
import { Footer } from "./components/Footer";
|
| 8 |
+
import { ProgressBar } from "./components/progressbar";
|
| 9 |
+
|
| 10 |
+
export function App() {
|
| 11 |
+
// ----------------------------
|
| 12 |
+
// State Management
|
| 13 |
+
// ----------------------------
|
| 14 |
+
const [selectedTest, setSelectedTest] = useState("cytology");
|
| 15 |
+
const [uploadedImage, setUploadedImage] = useState<string | null>(null);
|
| 16 |
+
const [selectedModel, setSelectedModel] = useState("");
|
| 17 |
+
const [apiResult, setApiResult] = useState<any>(null);
|
| 18 |
+
const [showResults, setShowResults] = useState(false);
|
| 19 |
+
const [currentStep, setCurrentStep] = useState(0);
|
| 20 |
+
const [loading, setLoading] = useState(false);
|
| 21 |
+
|
| 22 |
+
// ----------------------------
|
| 23 |
+
// Progress bar logic
|
| 24 |
+
// ----------------------------
|
| 25 |
+
useEffect(() => {
|
| 26 |
+
if (showResults) setCurrentStep(2);
|
| 27 |
+
else if (uploadedImage) setCurrentStep(1);
|
| 28 |
+
else setCurrentStep(0);
|
| 29 |
+
}, [uploadedImage, showResults]);
|
| 30 |
+
|
| 31 |
+
// ----------------------------
|
| 32 |
+
// Reset logic — new test
|
| 33 |
+
// ----------------------------
|
| 34 |
+
useEffect(() => {
|
| 35 |
+
setCurrentStep(0);
|
| 36 |
+
setShowResults(false);
|
| 37 |
+
setUploadedImage(null);
|
| 38 |
+
setSelectedModel("");
|
| 39 |
+
setApiResult(null);
|
| 40 |
+
}, [selectedTest]);
|
| 41 |
+
|
| 42 |
+
// ----------------------------
|
| 43 |
+
// Analyze handler (Backend call)
|
| 44 |
+
// ----------------------------
|
| 45 |
+
const handleAnalyze = async () => {
|
| 46 |
+
if (!uploadedImage || !selectedModel) {
|
| 47 |
+
alert("Please select a model and upload an image first!");
|
| 48 |
+
return;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
setLoading(true);
|
| 53 |
+
setShowResults(false);
|
| 54 |
+
setApiResult(null);
|
| 55 |
+
|
| 56 |
+
try {
|
| 57 |
+
// Convert Base64 → File
|
| 58 |
+
const blob = await fetch(uploadedImage).then((r) => r.blob());
|
| 59 |
+
const file = new File([blob], "input.jpg", { type: blob.type });
|
| 60 |
+
|
| 61 |
+
const formData = new FormData();
|
| 62 |
+
formData.append("file", file);
|
| 63 |
+
formData.append("analysis_type", selectedTest);
|
| 64 |
+
formData.append("model_name", selectedModel);
|
| 65 |
+
|
| 66 |
+
// POST to backend
|
| 67 |
+
const baseURL =
|
| 68 |
+
import.meta.env.MODE === "development"
|
| 69 |
+
? "http://127.0.0.1:8000"
|
| 70 |
+
: window.location.origin;
|
| 71 |
+
|
| 72 |
+
const res = await axios.post(`${baseURL}/predict/`, formData, {
|
| 73 |
+
headers: { "Content-Type": "multipart/form-data" },
|
| 74 |
+
});
|
| 75 |
+
|
| 76 |
+
setApiResult(res.data);
|
| 77 |
+
setShowResults(true);
|
| 78 |
+
} catch (err) {
|
| 79 |
+
console.error("❌ Error during inference:", err);
|
| 80 |
+
alert("Error analyzing the image. Check backend logs.");
|
| 81 |
+
} finally {
|
| 82 |
+
setLoading(false);
|
| 83 |
+
}
|
| 84 |
+
};
|
| 85 |
+
// ----------------------------
|
| 86 |
+
// Layout
|
| 87 |
+
// ----------------------------
|
| 88 |
+
return ( <div className="flex flex-col min-h-screen w-full bg-gray-50"> <Header /> <ProgressBar currentStep={currentStep} />
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
<div className="flex flex-1">
|
| 92 |
+
<Sidebar selectedTest={selectedTest} onTestChange={setSelectedTest} />
|
| 93 |
+
|
| 94 |
+
<main className="flex-1 p-6">
|
| 95 |
+
<div className="max-w-7xl mx-auto grid grid-cols-1 lg:grid-cols-2 gap-6">
|
| 96 |
+
{/* Upload & Model Selection */}
|
| 97 |
+
<UploadSection
|
| 98 |
+
selectedTest={selectedTest}
|
| 99 |
+
uploadedImage={uploadedImage}
|
| 100 |
+
setUploadedImage={setUploadedImage}
|
| 101 |
+
selectedModel={selectedModel}
|
| 102 |
+
setSelectedModel={setSelectedModel}
|
| 103 |
+
onAnalyze={handleAnalyze}
|
| 104 |
+
/>
|
| 105 |
+
|
| 106 |
+
{/* Results Panel */}
|
| 107 |
+
{showResults && (
|
| 108 |
+
<ResultsPanel
|
| 109 |
+
uploadedImage={
|
| 110 |
+
apiResult?.annotated_image_url || uploadedImage
|
| 111 |
+
}
|
| 112 |
+
result={apiResult}
|
| 113 |
+
loading={loading}
|
| 114 |
+
/>
|
| 115 |
+
)}
|
| 116 |
+
</div>
|
| 117 |
+
</main>
|
| 118 |
+
</div>
|
| 119 |
+
|
| 120 |
+
<Footer />
|
| 121 |
+
</div>
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
);
|
| 125 |
+
}
|
frontend/src/AppRouter.tsx
ADDED
|
@@ -0,0 +1,10 @@
import React from "react";
import { BrowserRouter, Routes, Route } from "react-router-dom";
import { App } from "./App";
export function AppRouter() {
  return <BrowserRouter>
      <Routes>
        <Route path="/" element={<App />} />
      </Routes>
    </BrowserRouter>;
}
frontend/src/components/Footer.tsx
ADDED
|
@@ -0,0 +1,50 @@
| 1 |
+
import React from 'react';
|
| 2 |
+
|
| 3 |
+
export function Footer() {
|
| 4 |
+
return (
|
| 5 |
+
<footer
|
| 6 |
+
className="relative w-full text-white py-10 mt-auto bg-cover bg-center"
|
| 7 |
+
style={{
|
| 8 |
+
backgroundImage:
|
| 9 |
+
"url('banner.jpeg')",
|
| 10 |
+
}}
|
| 11 |
+
>
|
| 12 |
+
{/* Overlay for readability */}
|
| 13 |
+
<div className="absolute inset-0 bg-gradient-to-t from-slate-900/95 via-slate-900/70 to-transparent" />
|
| 14 |
+
|
| 15 |
+
{/* Main footer content */}
|
| 16 |
+
<div className="relative max-w-7xl mx-auto px-8">
|
| 17 |
+
<div className="flex justify-center gap-8 mb-4">
|
| 18 |
+
<a
|
| 19 |
+
href="#"
|
| 20 |
+
className="relative text-blue-300 hover:text-blue-100 transition-all duration-300 after:content-[''] after:absolute after:left-0 after:bottom-0 after:w-0 hover:after:w-full after:h-[1px] after:bg-blue-300 after:transition-all after:duration-300"
|
| 21 |
+
>
|
| 22 |
+
Help Center
|
| 23 |
+
</a>
|
| 24 |
+
<a
|
| 25 |
+
href="#"
|
| 26 |
+
className="relative text-blue-300 hover:text-blue-100 transition-all duration-300 after:content-[''] after:absolute after:left-0 after:bottom-0 after:w-0 hover:after:w-full after:h-[1px] after:bg-blue-300 after:transition-all after:duration-300"
|
| 27 |
+
>
|
| 28 |
+
Contact Support
|
| 29 |
+
</a>
|
| 30 |
+
</div>
|
| 31 |
+
|
| 32 |
+
<div className="text-center">
|
| 33 |
+
<p className="font-semibold mb-2">© 2025 Manalife. All rights reserved.</p>
|
| 34 |
+
<p className="text-gray-300 text-sm">
|
| 35 |
+
Advancing innovation in women's health and digital pathology.
|
| 36 |
+
</p>
|
| 37 |
+
</div>
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
{/* Logo at bottom-right corner */}
|
| 41 |
+
<div className="absolute bottom-4 right-8">
|
| 42 |
+
<img
|
| 43 |
+
src="/white_logo.png"
|
| 44 |
+
alt="Manalife Logo"
|
| 45 |
+
className="h-12 w-auto opacity-90 hover:opacity-100 transition-opacity duration-300"
|
| 46 |
+
/>
|
| 47 |
+
</div>
|
| 48 |
+
</footer>
|
| 49 |
+
);
|
| 50 |
+
}
|
frontend/src/components/Header.tsx
ADDED
|
@@ -0,0 +1,40 @@
| 1 |
+
export function Header() {
|
| 2 |
+
return (
|
| 3 |
+
<div className="w-full bg-white">
|
| 4 |
+
{/* Banner */}
|
| 5 |
+
<div
|
| 6 |
+
className="w-full h-24 bg-cover bg-center"
|
| 7 |
+
style={{
|
| 8 |
+
backgroundImage:
|
| 9 |
+
"url('/banner.jpeg')",
|
| 10 |
+
}}
|
| 11 |
+
>
|
| 12 |
+
<div className="w-full h-full bg-gradient-to-r from-blue-900/80 to-teal-900/80 flex items-center px-8">
|
| 13 |
+
{/* Logo + Title */}
|
| 14 |
+
<div className="flex items-center gap-4">
|
| 15 |
+
<img
|
| 16 |
+
src="/white_logo.png"
|
| 17 |
+
alt="Manalife Logo"
|
| 18 |
+
className="h-16 w-auto"
|
| 19 |
+
/>
|
| 20 |
+
<h1 className="text-white text-2xl font-semibold tracking-wide">
|
| 21 |
+
Manalife AI Pathology Assistant
|
| 22 |
+
</h1>
|
| 23 |
+
</div>
|
| 24 |
+
</div>
|
| 25 |
+
</div>
|
| 26 |
+
|
| 27 |
+
{/* Disclaimer */}
|
| 28 |
+
<div className="bg-blue-50 border-b border-blue-200 px-8 py-3">
|
| 29 |
+
<h3 className="font-semibold text-blue-900 mb-1">Public Disclaimer</h3>
|
| 30 |
+
<p className="text-sm text-blue-800">
|
| 31 |
+
Manalife AI models are research prototypes developed to advance
|
| 32 |
+
innovation in women's health and digital pathology. They are not
|
| 33 |
+
certified medical devices and are not intended for direct diagnosis or
|
| 34 |
+
treatment. Clinical validation and regulatory approval are required
|
| 35 |
+
before any medical use.
|
| 36 |
+
</p>
|
| 37 |
+
</div>
|
| 38 |
+
</div>
|
| 39 |
+
);
|
| 40 |
+
}
|
frontend/src/components/ResultsPanel.tsx
ADDED
|
@@ -0,0 +1,143 @@
| 1 |
+
import { DownloadIcon, InfoIcon, Loader2Icon } from "lucide-react";
|
| 2 |
+
|
| 3 |
+
interface ResultsPanelProps {
|
| 4 |
+
uploadedImage: string | null;
|
| 5 |
+
result?: any;
|
| 6 |
+
loading?: boolean;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
export function ResultsPanel({ uploadedImage, result, loading }: ResultsPanelProps) {
|
| 10 |
+
if (loading) {
|
| 11 |
+
return ( <div className="bg-white rounded-lg shadow-sm p-6 flex flex-col items-center justify-center"> <Loader2Icon className="w-10 h-10 text-blue-600 animate-spin mb-3" /> <p className="text-gray-600 font-medium">Analyzing image...</p> </div>
|
| 12 |
+
);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
if (!result) {
|
| 16 |
+
return ( <div className="bg-white rounded-lg shadow-sm p-6 text-center text-gray-500">
|
| 17 |
+
No analysis result available yet. </div>
|
| 18 |
+
);
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
const {
|
| 22 |
+
prediction,
|
| 23 |
+
confidence,
|
| 24 |
+
probabilities,
|
| 25 |
+
detections,
|
| 26 |
+
summary,
|
| 27 |
+
annotated_image_url,
|
| 28 |
+
model_name,
|
| 29 |
+
analysis_type,
|
| 30 |
+
} = result;
|
| 31 |
+
|
| 32 |
+
const handleDownload = () => {
|
| 33 |
+
if (annotated_image_url) {
|
| 34 |
+
const link = document.createElement("a");
|
| 35 |
+
link.href = annotated_image_url;
|
| 36 |
+
link.download = "analysis_result.jpg";
|
| 37 |
+
link.click();
|
| 38 |
+
}
|
| 39 |
+
};
|
| 40 |
+
|
| 41 |
+
return ( <div className="bg-white rounded-lg shadow-sm p-6">
|
| 42 |
+
{/* Header */} <div className="flex items-center justify-between mb-6"> <div> <h2 className="text-2xl font-bold text-gray-800">
|
| 43 |
+
{model_name ? model_name.toUpperCase() : "Analysis Result"} </h2> <p className="text-sm text-gray-500 capitalize">
|
| 44 |
+
{analysis_type || "Test Type"} </p> </div>
|
| 45 |
+
{annotated_image_url && ( <button
|
| 46 |
+
onClick={handleDownload}
|
| 47 |
+
className="flex items-center gap-2 bg-green-600 text-white px-4 py-2 rounded-lg hover:bg-green-700 transition-colors"
|
| 48 |
+
> <DownloadIcon className="w-4 h-4" />
|
| 49 |
+
Download Image </button>
|
| 50 |
+
)} </div>
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
{/* Image */}
|
| 54 |
+
<div className="relative mb-6 rounded-lg overflow-hidden border border-gray-200">
|
| 55 |
+
<img
|
| 56 |
+
src={annotated_image_url || uploadedImage || "/ui.jpg"}
|
| 57 |
+
alt="Analysis Result"
|
| 58 |
+
className="w-full h-64 object-cover"
|
| 59 |
+
/>
|
| 60 |
+
</div>
|
| 61 |
+
|
| 62 |
+
{/* Results Summary */}
|
| 63 |
+
<div className="mb-6">
|
| 64 |
+
{prediction && (
|
| 65 |
+
<h3
|
| 66 |
+
className={`text-3xl font-bold ${
|
| 67 |
+
prediction.toLowerCase().includes("malignant") ||
|
| 68 |
+
prediction.toLowerCase().includes("abnormal")
|
| 69 |
+
? "text-red-600"
|
| 70 |
+
: "text-green-600"
|
| 71 |
+
}`}
|
| 72 |
+
>
|
| 73 |
+
{prediction}
|
| 74 |
+
</h3>
|
| 75 |
+
)}
|
| 76 |
+
|
| 77 |
+
{confidence && (
|
| 78 |
+
<div className="mt-2">
|
| 79 |
+
<div className="flex items-center justify-between mb-1">
|
| 80 |
+
<span className="font-semibold text-gray-900">
|
| 81 |
+
Confidence: {(confidence * 100).toFixed(2)}%
|
| 82 |
+
</span>
|
| 83 |
+
<InfoIcon className="w-4 h-4 text-gray-400" />
|
| 84 |
+
</div>
|
| 85 |
+
<div className="w-full h-3 bg-gray-200 rounded-full overflow-hidden">
|
| 86 |
+
<div
|
| 87 |
+
className={`h-full ${
|
| 88 |
+
confidence > 0.7
|
| 89 |
+
? "bg-green-500"
|
| 90 |
+
: confidence > 0.4
|
| 91 |
+
? "bg-yellow-500"
|
| 92 |
+
: "bg-red-500"
|
| 93 |
+
}`}
|
| 94 |
+
style={{ width: `${confidence * 100}%` }}
|
| 95 |
+
/>
|
| 96 |
+
</div>
|
| 97 |
+
</div>
|
| 98 |
+
)}
|
| 99 |
+
|
| 100 |
+
{summary && (
|
| 101 |
+
<p className="mt-4 text-gray-700 text-sm leading-relaxed">
|
| 102 |
+
{summary}
|
| 103 |
+
</p>
|
| 104 |
+
)}
|
| 105 |
+
</div>
|
| 106 |
+
|
| 107 |
+
{/* Detections / Probabilities */}
|
| 108 |
+
{detections && detections.length > 0 && (
|
| 109 |
+
<div className="mb-6">
|
| 110 |
+
<h4 className="font-semibold text-gray-900 mb-3">
|
| 111 |
+
Detected Regions:
|
| 112 |
+
</h4>
|
| 113 |
+
<ul className="text-sm text-gray-700 list-disc list-inside space-y-1">
|
| 114 |
+
{detections.map((det: any, i: number) => (
|
| 115 |
+
<li key={i}>
|
| 116 |
+
{det.name || "object"} – {(det.confidence * 100).toFixed(1)}%
|
| 117 |
+
</li>
|
| 118 |
+
))}
|
| 119 |
+
</ul>
|
| 120 |
+
</div>
|
| 121 |
+
)}
|
| 122 |
+
|
| 123 |
+
{probabilities && (
|
| 124 |
+
<div className="mb-6">
|
| 125 |
+
<h4 className="font-semibold text-gray-900 mb-3">
|
| 126 |
+
Class Probabilities:
|
| 127 |
+
</h4>
|
| 128 |
+
<pre className="bg-gray-100 rounded-lg p-3 text-sm">
|
| 129 |
+
{JSON.stringify(probabilities, null, 2)}
|
| 130 |
+
</pre>
|
| 131 |
+
</div>
|
| 132 |
+
)}
|
| 133 |
+
|
| 134 |
+
{/* Report Button */}
|
| 135 |
+
<button className="w-full bg-blue-600 text-white py-3 rounded-lg font-medium hover:bg-blue-700 transition-colors flex items-center justify-center gap-2">
|
| 136 |
+
<DownloadIcon className="w-5 h-5" />
|
| 137 |
+
Generate Report
|
| 138 |
+
</button>
|
| 139 |
+
</div>
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
);
|
| 143 |
+
}
|
frontend/src/components/Sidebar.tsx
ADDED
|
@@ -0,0 +1,54 @@
| 1 |
+
import React, { useState } from 'react';
|
| 2 |
+
import { ChevronDownIcon, FileTextIcon, HelpCircleIcon } from 'lucide-react';
|
| 3 |
+
interface SidebarProps {
|
| 4 |
+
selectedTest: string;
|
| 5 |
+
onTestChange: (test: string) => void;
|
| 6 |
+
}
|
| 7 |
+
export function Sidebar({
|
| 8 |
+
selectedTest,
|
| 9 |
+
onTestChange
|
| 10 |
+
}: SidebarProps) {
|
| 11 |
+
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
|
| 12 |
+
const testTypes = [{
|
| 13 |
+
value: 'cytology',
|
| 14 |
+
label: 'Cytology Analysis'
|
| 15 |
+
}, {
|
| 16 |
+
value: 'colposcopy',
|
| 17 |
+
label: 'Colposcopy Analysis'
|
| 18 |
+
}, {
|
| 19 |
+
value: 'histopathology',
|
| 20 |
+
label: 'Histopathology Analysis'
|
| 21 |
+
}];
|
| 22 |
+
return <aside className="w-64 bg-white border-r border-gray-200 p-4">
|
| 23 |
+
<div className="space-y-2">
|
| 24 |
+
{/* New Test Dropdown */}
|
| 25 |
+
<div className="relative">
|
| 26 |
+
<button onClick={() => setIsDropdownOpen(!isDropdownOpen)} className="w-full bg-blue-600 text-white rounded-lg px-4 py-3 flex items-center justify-between hover:bg-blue-700 transition-colors">
|
| 27 |
+
<div className="flex items-center gap-2">
|
| 28 |
+
<div className="w-2 h-2 bg-white rounded-full" />
|
| 29 |
+
<span className="font-medium">New Test</span>
|
| 30 |
+
</div>
|
| 31 |
+
<ChevronDownIcon className={`w-5 h-5 transition-transform ${isDropdownOpen ? 'rotate-180' : ''}`} />
|
| 32 |
+
</button>
|
| 33 |
+
{isDropdownOpen && <div className="absolute top-full left-0 right-0 mt-2 bg-white border border-gray-200 rounded-lg shadow-lg z-10">
|
| 34 |
+
{testTypes.map(test => <button key={test.value} onClick={() => {
|
| 35 |
+
onTestChange(test.value);
|
| 36 |
+
setIsDropdownOpen(false);
|
| 37 |
+
}} className={`w-full text-left px-4 py-3 hover:bg-gray-50 transition-colors ${selectedTest === test.value ? 'bg-blue-50 text-blue-600' : 'text-gray-700'}`}>
|
| 38 |
+
{test.label}
|
| 39 |
+
</button>)}
|
| 40 |
+
</div>}
|
| 41 |
+
</div>
|
| 42 |
+
{/* History */}
|
| 43 |
+
<button className="w-full flex items-center gap-3 px-4 py-3 text-gray-700 hover:bg-gray-50 rounded-lg transition-colors">
|
| 44 |
+
<FileTextIcon className="w-5 h-5" />
|
| 45 |
+
<span>History</span>
|
| 46 |
+
</button>
|
| 47 |
+
{/* Help */}
|
| 48 |
+
<button className="w-full flex items-center gap-3 px-4 py-3 text-gray-700 hover:bg-gray-50 rounded-lg transition-colors">
|
| 49 |
+
<HelpCircleIcon className="w-5 h-5" />
|
| 50 |
+
<span>Help</span>
|
| 51 |
+
</button>
|
| 52 |
+
</div>
|
| 53 |
+
</aside>;
|
| 54 |
+
}
|
frontend/src/components/UploadSection.tsx
ADDED
|
@@ -0,0 +1,194 @@
| 1 |
+
import React, { useRef } from 'react';
|
| 2 |
+
import { UploadIcon } from 'lucide-react';
|
| 3 |
+
|
| 4 |
+
interface UploadSectionProps {
|
| 5 |
+
selectedTest: string;
|
| 6 |
+
uploadedImage: string | null;
|
| 7 |
+
setUploadedImage: (image: string | null) => void;
|
| 8 |
+
selectedModel: string;
|
| 9 |
+
setSelectedModel: (model: string) => void;
|
| 10 |
+
onAnalyze: () => void;
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
export function UploadSection({
|
| 16 |
+
selectedTest,
|
| 17 |
+
uploadedImage,
|
| 18 |
+
setUploadedImage,
|
| 19 |
+
selectedModel,
|
| 20 |
+
setSelectedModel,
|
| 21 |
+
onAnalyze,
|
| 22 |
+
}: UploadSectionProps) {
|
| 23 |
+
const fileInputRef = useRef<HTMLInputElement>(null);
|
| 24 |
+
|
| 25 |
+
const modelOptions = {
|
| 26 |
+
cytology: [
|
| 27 |
+
{ value: 'mwt', label: 'MWT' },
|
| 28 |
+
{ value: 'yolo', label: 'YOLOv8' },
|
| 29 |
+
],
|
| 30 |
+
colposcopy: [
|
| 31 |
+
{ value: 'cin', label: 'Logistic-Colpo' },
|
| 32 |
+
|
| 33 |
+
],
|
| 34 |
+
histopathology: [
|
| 35 |
+
{ value: 'histopathology', label: 'Path Foundation Model' },
|
| 36 |
+
|
| 37 |
+
],
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
const sampleImages = {
|
| 42 |
+
cytology: [
|
| 43 |
+
"/cyto/cyt1.jpg",
|
| 44 |
+
"/cyto/cyt2.png",
|
| 45 |
+
"/cyto/cyt3.png",
|
| 46 |
+
],
|
| 47 |
+
colposcopy: [
|
| 48 |
+
"/colpo/colp1.jpg",
|
| 49 |
+
"/colpo/colp2.jpg",
|
| 50 |
+
"/colpo/colp3.jpg",
|
| 51 |
+
],
|
| 52 |
+
histopathology: [
|
| 53 |
+
"/histo/hist1.png",
|
| 54 |
+
"/histo/hist2.png",
|
| 55 |
+
"/histo/hist3.jpg",
|
| 56 |
+
],
|
| 57 |
+
};
|
| 58 |
+
|
| 59 |
+
const currentModels =
|
| 60 |
+
modelOptions[selectedTest as keyof typeof modelOptions] || [];
|
| 61 |
+
|
| 62 |
+
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
| 63 |
+
const file = e.target.files?.[0];
|
| 64 |
+
if (file) {
|
| 65 |
+
const reader = new FileReader();
|
| 66 |
+
reader.onload = (event) => {
|
| 67 |
+
setUploadedImage(event.target?.result as string);
|
| 68 |
+
};
|
| 69 |
+
reader.readAsDataURL(file);
|
| 70 |
+
}
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
const handleDrop = (e: React.DragEvent) => {
|
| 74 |
+
e.preventDefault();
|
| 75 |
+
const file = e.dataTransfer.files[0];
|
| 76 |
+
if (file) {
|
| 77 |
+
const reader = new FileReader();
|
| 78 |
+
reader.onload = (event) => {
|
| 79 |
+
setUploadedImage(event.target?.result as string);
|
| 80 |
+
};
|
| 81 |
+
reader.readAsDataURL(file);
|
| 82 |
+
}
|
| 83 |
+
};
|
| 84 |
+
|
| 85 |
+
// Handle click on sample image
|
| 86 |
+
const handleSampleClick = (imgUrl: string) => {
|
| 87 |
+
setUploadedImage(imgUrl);
|
| 88 |
+
};
|
| 89 |
+
|
| 90 |
+
return (
|
| 91 |
+
<div className="bg-white rounded-lg shadow-sm p-6">
|
| 92 |
+
<h2 className="text-2xl font-semibold text-gray-900 mb-6">
|
| 93 |
+
Upload an image of a tissue sample
|
| 94 |
+
</h2>
|
| 95 |
+
|
| 96 |
+
{/* Upload Area */}
|
| 97 |
+
<div
|
| 98 |
+
onDrop={handleDrop}
|
| 99 |
+
onDragOver={(e) => e.preventDefault()}
|
| 100 |
+
onClick={() => fileInputRef.current?.click()}
|
| 101 |
+
className="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center cursor-pointer hover:border-blue-400 transition-colors"
|
| 102 |
+
>
|
| 103 |
+
<input
|
| 104 |
+
ref={fileInputRef}
|
| 105 |
+
type="file"
|
| 106 |
+
accept="image/*"
|
| 107 |
+
onChange={handleFileChange}
|
| 108 |
+
className="hidden"
|
| 109 |
+
/>
|
| 110 |
+
<div className="flex flex-col items-center">
|
| 111 |
+
<div className="w-16 h-16 bg-gray-100 rounded-full flex items-center justify-center mb-4">
|
| 112 |
+
<UploadIcon className="w-8 h-8 text-gray-400" />
|
| 113 |
+
</div>
|
| 114 |
+
{uploadedImage ? (
|
| 115 |
+
<>
|
| 116 |
+
<p className="text-green-600 font-medium mb-2">
|
| 117 |
+
Image uploaded successfully!
|
| 118 |
+
</p>
|
| 119 |
+
<p className="text-sm text-gray-500 mb-4">
|
| 120 |
+
Click to upload a different image
|
| 121 |
+
</p>
|
| 122 |
+
<div className="w-32 h-32 rounded-lg overflow-hidden border border-gray-200">
|
| 123 |
+
<img
|
| 124 |
+
src={uploadedImage}
|
| 125 |
+
alt="Uploaded sample"
|
| 126 |
+
className="w-full h-full object-cover"
|
| 127 |
+
/>
|
| 128 |
+
</div>
|
| 129 |
+
</>
|
| 130 |
+
) : (
|
| 131 |
+
<p className="text-gray-600">Drag and drop or click to upload</p>
|
| 132 |
+
)}
|
| 133 |
+
</div>
|
| 134 |
+
</div>
|
| 135 |
+
|
| 136 |
+
{/* Model Selection */}
|
| 137 |
+
<div className="mt-6">
|
| 138 |
+
<label className="block text-sm font-medium text-gray-700 mb-2">
|
| 139 |
+
Select Analysis Model:
|
| 140 |
+
</label>
|
| 141 |
+
<select
|
| 142 |
+
value={selectedModel}
|
| 143 |
+
onChange={(e) => setSelectedModel(e.target.value)}
|
| 144 |
+
className="w-full px-4 py-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
| 145 |
+
>
|
| 146 |
+
<option value="">Choose a model...</option>
|
| 147 |
+
{currentModels.map((model) => (
|
| 148 |
+
<option key={model.value} value={model.value}>
|
| 149 |
+
{model.label}
|
| 150 |
+
</option>
|
| 151 |
+
))}
|
| 152 |
+
</select>
|
| 153 |
+
</div>
|
| 154 |
+
|
| 155 |
+
{/* Analyze Button */}
|
| 156 |
+
<button
|
| 157 |
+
onClick={onAnalyze}
|
| 158 |
+
disabled={!uploadedImage || !selectedModel}
|
| 159 |
+
className="w-full mt-6 bg-blue-600 text-white py-3 rounded-lg font-medium hover:bg-blue-700 disabled:bg-gray-300 disabled:cursor-not-allowed transition-colors"
|
| 160 |
+
>
|
| 161 |
+
Analyze
|
| 162 |
+
</button>
|
| 163 |
+
|
| 164 |
+
{/* Separator */}
|
| 165 |
+
<hr className="my-8 border-gray-200" />
|
| 166 |
+
|
| 167 |
+
{/* Sample Images Section */}
|
| 168 |
+
<div>
|
| 169 |
+
<h3 className="text-lg font-semibold text-gray-800 mb-4">
|
| 170 |
+
Samples Images
|
+        </h3>
+        <div className="flex flex-wrap gap-4">
+          {(sampleImages[selectedTest as keyof typeof sampleImages] || []).map(
+            (img, index) => (
+              <div
+                key={index}
+                className={`w-20 h-20 rounded-lg border-2 cursor-pointer transition-transform hover:scale-105 hover:border-blue-500 overflow-hidden ${
+                  uploadedImage === img ? 'border-blue-600' : 'border-gray-300'
+                }`}
+                onClick={() => handleSampleClick(img)}
+              >
+                <img
+                  src={img}
+                  alt={`Sample ${index + 1}`}
+                  className="w-full h-full object-cover"
+                />
+              </div>
+            )
+          )}
+        </div>
+      </div>
+    </div>
+  );
+}
frontend/src/components/progressbar.tsx
ADDED
@@ -0,0 +1,61 @@
+import { Fragment } from "react";
+import { CheckIcon, FileTextIcon } from "lucide-react";
+
+interface ProgressBarProps {
+  currentStep: number;
+}
+
+export function ProgressBar({ currentStep }: ProgressBarProps) {
+  const steps = [
+    { label: "Upload", index: 0 },
+    { label: "Analyze", index: 1 },
+    { label: "Report", index: 2 },
+  ];
+
+  return (
+    <div className="bg-white px-8 py-6 border-b border-gray-200">
+      <div className="max-w-2xl mx-auto flex items-center justify-center">
+        {steps.map((step, index) => (
+          <Fragment key={step.label}>
+            <div className="flex flex-col items-center">
+              {/* Step circle */}
+              <div
+                className={`w-12 h-12 rounded-full flex items-center justify-center text-white font-semibold transition-all duration-300 ${
+                  currentStep > step.index
+                    ? "bg-gradient-to-r from-blue-800 to-teal-600"
+                    : currentStep === step.index
+                    ? "bg-gradient-to-r from-blue-600 to-teal-500"
+                    : "bg-gray-300 text-gray-600"
+                }`}
+              >
+                {currentStep > step.index ? (
+                  <CheckIcon className="w-6 h-6" />
+                ) : step.index === 2 ? (
+                  <FileTextIcon className="w-6 h-6" />
+                ) : (
+                  <span>{index + 1}</span>
+                )}
+              </div>
+
+              {/* Label */}
+              <span className="mt-2 text-sm font-medium text-gray-700">
+                {step.label}
+              </span>
+            </div>
+
+            {/* Connecting line */}
+            {index < steps.length - 1 && (
+              <div
+                className={`h-1 w-32 mx-4 rounded-full transition-all duration-300 ${
+                  currentStep > step.index
+                    ? "bg-gradient-to-r from-blue-800 to-teal-600"
+                    : "bg-gray-300"
+                }`}
+              />
+            )}
+          </Fragment>
+        ))}
+      </div>
+    </div>
+  );
+}
frontend/src/index.css
ADDED
@@ -0,0 +1,5 @@
+/* PLEASE NOTE: THESE TAILWIND IMPORTS SHOULD NEVER BE DELETED */
+@import 'tailwindcss/base';
+@import 'tailwindcss/components';
+@import 'tailwindcss/utilities';
+/* DO NOT DELETE THESE TAILWIND IMPORTS, OTHERWISE THE STYLING WILL NOT RENDER AT ALL */
frontend/src/index.tsx
ADDED
@@ -0,0 +1,5 @@
+import './index.css';
+import React from "react";
+import { render } from "react-dom";
+import { App } from "./App";
+render(<App />, document.getElementById("root"));