malavikapradeep2001 committed
Commit bf5da6b · 1 Parent(s): 7f72e7d

Initial Space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. Dockerfile +43 -0
  3. README.md +3 -4
  4. backend/MWTclass2.pth +3 -0
  5. backend/__pycache__/app0.cpython-312.pyc +0 -0
  6. backend/__pycache__/app2.cpython-312.pyc +0 -0
  7. backend/__pycache__/app3.cpython-312.pyc +0 -0
  8. backend/__pycache__/app4.cpython-312.pyc +0 -0
  9. backend/__pycache__/augmentations.cpython-312.pyc +0 -0
  10. backend/__pycache__/model.cpython-312.pyc +0 -0
  11. backend/__pycache__/model_histo.cpython-312.pyc +0 -0
  12. backend/app.py +241 -0
  13. backend/augmentations.py +328 -0
  14. backend/best2.pt +3 -0
  15. backend/histopathology_trained_model.keras +3 -0
  16. backend/logistic_regression_model.pkl +3 -0
  17. backend/model.py +521 -0
  18. backend/model_histo.py +1495 -0
  19. backend/requirements.txt +14 -0
  20. backend/yolo_colposcopy.pt +3 -0
  21. frontend/.eslintrc.cjs +18 -0
  22. frontend/.gitignore +24 -0
  23. frontend/README.md +8 -0
  24. frontend/index.html +13 -0
  25. frontend/package-lock.json +0 -0
  26. frontend/package.json +35 -0
  27. frontend/postcss.config.js +6 -0
  28. frontend/public/banner.jpeg +0 -0
  29. frontend/public/black_logo.png +0 -0
  30. frontend/public/colpo/colp1.jpg +0 -0
  31. frontend/public/colpo/colp2.jpg +0 -0
  32. frontend/public/colpo/colp3.jpg +0 -0
  33. frontend/public/cyto/cyt1.jpg +3 -0
  34. frontend/public/cyto/cyt2.png +3 -0
  35. frontend/public/cyto/cyt3.png +3 -0
  36. frontend/public/histo/hist1.png +0 -0
  37. frontend/public/histo/hist2.png +0 -0
  38. frontend/public/histo/hist3.jpg +0 -0
  39. frontend/public/manalife_LOGO.jpg +0 -0
  40. frontend/public/white_logo.png +0 -0
  41. frontend/src/App.tsx +125 -0
  42. frontend/src/AppRouter.tsx +10 -0
  43. frontend/src/components/Footer.tsx +50 -0
  44. frontend/src/components/Header.tsx +40 -0
  45. frontend/src/components/ResultsPanel.tsx +143 -0
  46. frontend/src/components/Sidebar.tsx +54 -0
  47. frontend/src/components/UploadSection.tsx +194 -0
  48. frontend/src/components/progressbar.tsx +61 -0
  49. frontend/src/index.css +5 -0
  50. frontend/src/index.tsx +5 -0
.gitattributes CHANGED
@@ -33,5 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+backend/histopathology_trained_model.keras filter=lfs diff=lfs merge=lfs -text
+frontend/public/cyto/*.jpg filter=lfs diff=lfs merge=lfs -text
+frontend/public/cyto/*.png filter=lfs diff=lfs merge=lfs -text
 backend/outputs/** filter=lfs diff=lfs merge=lfs -text
 frontend/public/cyto/** filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,43 @@
1
+ # -----------------------------
2
+ # 1️⃣ Build Frontend
3
+ # -----------------------------
4
+ FROM node:18-bullseye AS frontend-builder
5
+ WORKDIR /app/frontend
6
+ COPY frontend/package*.json ./
7
+ RUN npm install
8
+ COPY frontend/ .
9
+ RUN npm run build
10
+
11
+ # -----------------------------
12
+ # 2️⃣ Build Backend
13
+ # -----------------------------
14
+ FROM python:3.10-slim-bullseye
15
+
16
+ WORKDIR /app
17
+
18
+ # Install essential system libraries for OpenCV, YOLO, and Ultralytics
19
+ RUN apt-get update && apt-get install -y \
20
+ libgl1 \
21
+ libglib2.0-0 \
22
+ libgomp1 \
23
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
24
+
25
+ # Copy backend source code
26
+ COPY backend/ .
27
+
28
+ # Copy built frontend into the right folder for FastAPI
29
+ # ✅ this must match your app.mount() path in app.py
30
+ COPY --from=frontend-builder /app/frontend/dist ./frontend/dist
31
+
32
+ # Install Python dependencies
33
+ RUN pip install --upgrade pip
34
+ RUN pip install -r requirements.txt || pip install -r backend/requirements.txt || true
35
+
36
+ # Install runtime dependencies explicitly
37
+ RUN pip install --no-cache-dir fastapi uvicorn python-multipart ultralytics opencv-python-headless pillow numpy scikit-learn tensorflow keras
38
+
39
+ # Hugging Face Spaces expect port 7860
40
+ EXPOSE 7860
41
+
42
+ # Run FastAPI app
43
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,10 @@
 ---
-title: Pathora
-emoji: 🐨
-colorFrom: red
+title: Proj Demo
+emoji: 🚀
+colorFrom: purple
 colorTo: green
 sdk: docker
 pinned: false
-short_description: Manalife's AI Pathology Assistant
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
backend/MWTclass2.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09fae21e537056bc61f9ab4bb08249b591697bb087546cf639c599c70b8c6a2c
3
+ size 79777797
backend/__pycache__/app0.cpython-312.pyc ADDED
Binary file (13 kB).
 
backend/__pycache__/app2.cpython-312.pyc ADDED
Binary file (3.66 kB).
 
backend/__pycache__/app3.cpython-312.pyc ADDED
Binary file (4 kB).
 
backend/__pycache__/app4.cpython-312.pyc ADDED
Binary file (9.17 kB).
 
backend/__pycache__/augmentations.cpython-312.pyc ADDED
Binary file (20.9 kB).
 
backend/__pycache__/model.cpython-312.pyc ADDED
Binary file (31.3 kB).
 
backend/__pycache__/model_histo.cpython-312.pyc ADDED
Binary file (69.2 kB).
 
backend/app.py ADDED
@@ -0,0 +1,241 @@
1
+ from fastapi import FastAPI, File, UploadFile, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse, FileResponse  # FileResponse is used by the index route below
4
+ from fastapi.staticfiles import StaticFiles
5
+ from ultralytics import YOLO
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import uvicorn
9
+ import json, os, uuid, numpy as np, torch, cv2, joblib, io, tensorflow as tf
10
+ import torch.nn as nn
11
+ import torchvision.transforms as transforms
12
+ import torchvision.models as models
13
+ from sklearn.preprocessing import MinMaxScaler
14
+ from model import MWT as create_model
15
+ from augmentations import Augmentations
16
+ from model_histo import BreastCancerClassifier # TensorFlow model
17
+
18
+ from huggingface_hub import login
19
+ import os
20
+
21
+ hf_token = os.getenv("HF_TOKEN")
22
+ if hf_token:
23
+ login(token=hf_token)
24
+
25
+
26
+ # =====================================================
27
+ # App setup
28
+ # =====================================================
29
+ app = FastAPI(title="Unified Cervical & Breast Cancer Analysis API")
30
+
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["*"],
34
+ allow_credentials=True,
35
+ allow_methods=["*"],
36
+ allow_headers=["*"],
37
+ )
38
+
39
+ OUTPUT_DIR = "outputs"
40
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
41
+ app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
42
+
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+
45
+ # =====================================================
46
+ # Model 1: YOLO (Colposcopy Detection)
47
+ # =====================================================
48
+ print("🔹 Loading YOLO model...")
49
+ yolo_model = YOLO("best2.pt")
50
+
51
+ # =====================================================
52
+ # Model 2: MWT Classifier
53
+ # =====================================================
54
+ print("🔹 Loading MWT model...")
55
+ mwt_model = create_model(num_classes=2).to(device)
56
+ mwt_model.load_state_dict(torch.load("MWTclass2.pth", map_location=device))
57
+ mwt_model.eval()
58
+ mwt_class_names = ['neg', 'pos']
59
+
60
+ # =====================================================
61
+ # Model 3: CIN Classifier
62
+ # =====================================================
63
+ print("🔹 Loading CIN model...")
64
+ clf = joblib.load("logistic_regression_model.pkl")
65
+ yolo_colposcopy = YOLO("yolo_colposcopy.pt")
66
+
67
+ def build_resnet(model_name="resnet50"):
68
+ if model_name == "resnet50":
69
+ model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
70
+ elif model_name == "resnet101":
71
+ model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
72
+ elif model_name == "resnet152":
73
+ model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
74
+ model.eval().to(device)
75
+ return (
76
+ nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool),
77
+ model.layer1, model.layer2, model.layer3, model.layer4,
78
+ )
79
+
80
+ gap = nn.AdaptiveAvgPool2d((1, 1))
81
+ gmp = nn.AdaptiveMaxPool2d((1, 1))
82
+ resnet50_blocks = build_resnet("resnet50")
83
+ resnet101_blocks = build_resnet("resnet101")
84
+ resnet152_blocks = build_resnet("resnet152")
85
+
86
+ transform = transforms.Compose([
87
+ transforms.ToPILImage(),
88
+ transforms.Resize((224, 224)),
89
+ transforms.ToTensor(),
90
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
91
+ std=[0.229, 0.224, 0.225]),
92
+ ])
93
+
94
+ # =====================================================
95
+ # Model 4: Histopathology Classifier (TensorFlow)
96
+ # =====================================================
97
+ print("🔹 Loading Breast Cancer Histopathology model...")
98
+ classifier = BreastCancerClassifier(fine_tune=False)
99
+ if not classifier.authenticate_huggingface():
100
+ raise RuntimeError("HuggingFace authentication failed.")
101
+ if not classifier.load_path_foundation():
102
+ raise RuntimeError("Failed to load Path Foundation model.")
103
+ model_path = "histopathology_trained_model.keras"
104
+ classifier.model = tf.keras.models.load_model(model_path)
105
+ print(f"✅ Loaded model from {model_path}")
106
+
107
+ # =====================================================
108
+ # Helper functions
109
+ # =====================================================
110
+ def preprocess_for_mwt(image_np):
111
+ img = cv2.resize(image_np, (224, 224))
112
+ img = Augmentations.Normalization((0, 1))(img)
113
+ img = np.array(img, np.float32)
114
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
115
+ img = img.transpose(2, 0, 1)
116
+ img = np.expand_dims(img, axis=0)
117
+ return torch.Tensor(img)
118
+
119
+ def extract_cbf_features(blocks, img_t):
120
+ block1, block2, block3, block4, block5 = blocks
121
+ with torch.no_grad():
122
+ f1 = block1(img_t)
123
+ f2 = block2(f1)
124
+ f3 = block3(f2)
125
+ f4 = block4(f3)
126
+ f5 = block5(f4)
127
+ p1 = gmp(f1).view(-1)
128
+ p2 = gmp(f2).view(-1)
129
+ p3 = gap(f3).view(-1)
130
+ p4 = gap(f4).view(-1)
131
+ p5 = gap(f5).view(-1)
132
+ cbf_feature = torch.cat([p1, p2, p3, p4, p5], dim=0)
133
+ return cbf_feature.cpu().numpy()
134
+
135
+ def predict_histopathology(image: Image.Image):
136
+ if image.mode != "RGB":
137
+ image = image.convert("RGB")
138
+ image = image.resize((224, 224))
139
+ img_array = np.expand_dims(np.array(image).astype("float32") / 255.0, axis=0)
140
+ embeddings = classifier.extract_embeddings(img_array)
141
+ prediction_proba = classifier.model.predict(embeddings, verbose=0)[0]
142
+ predicted_class = int(np.argmax(prediction_proba))
143
+ class_names = ["Benign", "Malignant"]
144
+ return {
145
+ "model_used": "Breast Cancer Histopathology Classifier",
146
+ "prediction": class_names[predicted_class],
147
+ "confidence": float(np.max(prediction_proba)),
148
+ "probabilities": {
149
+ "Benign": float(prediction_proba[0]),
150
+ "Malignant": float(prediction_proba[1])
151
+ }
152
+ }
153
+
154
+ # =====================================================
155
+ # Main endpoint
156
+ # =====================================================
157
+ @app.post("/predict/")
158
+ async def predict(model_name: str = Form(...), file: UploadFile = File(...)):
159
+ contents = await file.read()
160
+ image = Image.open(BytesIO(contents)).convert("RGB")
161
+ image_np = np.array(image)
162
+
163
+ if model_name == "yolo":
164
+ results = yolo_model(image)
165
+ detections_json = results[0].to_json()
166
+ detections = json.loads(detections_json)
167
+ output_filename = f"detected_{uuid.uuid4().hex[:8]}.jpg"
168
+ output_path = os.path.join(OUTPUT_DIR, output_filename)
169
+ results[0].save(filename=output_path)
170
+ return {
171
+ "model_used": "YOLO Detection",
172
+ "detections": detections,
173
+ "annotated_image_url": f"/outputs/{output_filename}"
174
+ }
175
+
176
+ elif model_name == "mwt":
177
+ tensor = preprocess_for_mwt(image_np)
178
+ with torch.no_grad():
179
+ output = mwt_model(tensor.to(device)).cpu()
180
+ probs = torch.softmax(output, dim=1)[0]
181
+ confidences = {mwt_class_names[i]: float(probs[i]) for i in range(2)}
182
+ predicted_label = mwt_class_names[torch.argmax(probs)]
183
+ return {"model_used": "MWT Classifier", "prediction": predicted_label, "confidence": confidences}
184
+
185
+ elif model_name == "cin":
186
+ nparr = np.frombuffer(contents, np.uint8)
187
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
188
+ results = yolo_colposcopy.predict(source=img, conf=0.7, save=False, verbose=False)
189
+ if len(results[0].boxes) == 0:
190
+ return {"error": "No cervix detected"}
191
+ x1, y1, x2, y2 = map(int, results[0].boxes.xyxy[0].cpu().numpy())
192
+ crop = img[y1:y2, x1:x2]
193
+ crop = cv2.resize(crop, (224, 224))
194
+ img_t = transform(crop).unsqueeze(0).to(device)
195
+ f50 = extract_cbf_features(resnet50_blocks, img_t)
196
+ f101 = extract_cbf_features(resnet101_blocks, img_t)
197
+ f152 = extract_cbf_features(resnet152_blocks, img_t)
198
+ features = np.concatenate([f50, f101, f152]).reshape(1, -1)
199
+ X_scaled = MinMaxScaler().fit_transform(features)
200
+ pred = clf.predict(X_scaled)[0]
201
+ proba = clf.predict_proba(X_scaled)[0]
202
+ classes = ["CIN1", "CIN2", "CIN3"]
203
+ return {
204
+ "model_used": "CIN Classifier",
205
+ "prediction": classes[pred],
206
+ "probabilities": dict(zip(classes, map(float, proba)))
207
+ }
208
+
209
+ elif model_name == "histopathology":
210
+ result = predict_histopathology(image)
211
+ return result
212
+
213
+ else:
214
+ return JSONResponse(content={"error": "Invalid model name"}, status_code=400)
215
+
216
+
217
+ @app.get("/models")
218
+ def get_models():
219
+ return {"available_models": ["yolo", "mwt", "cin", "histopathology"]}
220
+
221
+
222
+ @app.get("/health")
223
+ def health():
224
+ return {"message": "Unified Cervical & Breast Cancer API is running!"}
225
+
226
+ # Serve the built frontend ("/outputs" is already mounted above; StaticFiles is imported at the top)
227
+ app.mount("/assets", StaticFiles(directory="frontend/dist/assets"), name="assets")
228
+
229
+ app.mount("/", StaticFiles(directory="frontend/dist", html=True), name="static")
232
+
233
+
234
+ @app.get("/")
235
+ async def serve_frontend():
236
+ index_path = os.path.join("frontend", "dist", "index.html")
237
+ return FileResponse(index_path)
238
+
239
+ if __name__ == "__main__":
240
+ uvicorn.run(app, host="0.0.0.0", port=7860)
241
+
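For quick reference, a minimal client sketch for the `/predict/` endpoint defined in app.py above. It assumes the Space is reachable at http://localhost:7860 (the port exposed in the Dockerfile); the image path is a placeholder.

```python
import requests

BASE_URL = "http://localhost:7860"  # assumption: Space running locally on the exposed port

# The API lists the accepted model names: "yolo", "mwt", "cin", "histopathology"
print(requests.get(f"{BASE_URL}/models").json())

# Send an image to one of the models as multipart/form-data
with open("sample_colposcopy.jpg", "rb") as fh:  # hypothetical input file
    response = requests.post(
        f"{BASE_URL}/predict/",
        data={"model_name": "mwt"},  # Form field expected by the endpoint
        files={"file": ("sample_colposcopy.jpg", fh, "image/jpeg")},
    )
print(response.json())
```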
backend/augmentations.py ADDED
@@ -0,0 +1,328 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ This file is a part of project "Aided-Diagnosis-System-for-Cervical-Cancer-Screening".
4
+ See https://github.com/ShenghuaCheng/Aided-Diagnosis-System-for-Cervical-Cancer-Screening for more information.
5
+ File name: augmentations.py
6
+ Description: augmentation functions.
7
+ """
8
+
9
+ import functools
10
+ import random
11
+ import cv2
12
+ import numpy as np
13
+ from skimage.exposure import adjust_gamma
14
+ from loguru import logger
15
+
16
+ __all__ = [
17
+ "Augmentations",
18
+ "StylisticTrans",
19
+ "SpatialTrans",
20
+ ]
21
+
22
+
23
+ class Augmentations:
24
+ """
25
+ All parameters in each augmentation have been fixed to a suitable range.
26
+ img = [size, size, ch]
27
+ ch = 3: only img
28
+ ch = 4: img with mask at 4th dim
29
+ """
30
+
31
+ @staticmethod
32
+ def Compose(*funcs):
33
+ funcs = list(funcs)
34
+ func_names = [f.__name__ for f in funcs]
35
+ # ensure the norm opt is the last opt
36
+ if 'norm' in func_names:
37
+ idx = func_names.index('norm')
38
+ funcs = funcs[:idx] + funcs[idx + 1:] + [funcs[idx]]  # move the norm op to the end instead of duplicating it
39
+
40
+ def compose(img: np.ndarray):
41
+ return functools.reduce(lambda f, g: lambda x: g(f(x)), funcs)(img)
42
+
43
+ return compose
44
+
45
+ """
46
+ # ===========================================================================================================
47
+ # random stylistic augmentations
48
+ """
49
+
50
+ @staticmethod
51
+ def RandomGamma(p: float = 0.5):
52
+ def random_gamma(img: np.ndarray):
53
+ if random.random() < p:
54
+ gamma = 0.6 + random.random() * 0.6
55
+ img[..., :3] = StylisticTrans.gamma_adjust(img[..., :3], gamma)
56
+ return img
57
+
58
+ return random_gamma
59
+
60
+ @staticmethod
61
+ def RandomSharp(p: float = 0.5):
62
+ def random_sharp(img: np.ndarray):
63
+ if random.random() < p:
64
+ sigma = 8.3 + random.random() * 0.4
65
+ img[..., :3] = StylisticTrans.sharp(img[..., :3], sigma)
66
+ return img
67
+
68
+ return random_sharp
69
+
70
+ @staticmethod
71
+ def RandomGaussainBlur(p: float = 0.5):
72
+ def random_gaussian_blur(img: np.ndarray):
73
+ if random.random() < p:
74
+ sigma = 0.1 + random.random() * 1
75
+ img[..., :3] = StylisticTrans.gaussian_blur(img[..., :3], sigma)
76
+ return img
77
+
78
+ return random_gaussian_blur
79
+
80
+ @staticmethod
81
+ def RandomHSVDisturb(p: float = 0.5):
82
+ def random_hsv_disturb(img: np.ndarray):
83
+ if random.random() < p:
84
+ k = np.random.random(3) * [0.1, 0.8, 0.45] + [0.95, 0.7, 0.75]
85
+ b = np.random.random(3) * [6, 20, 18] + [-3, -10, -10]
86
+ img[..., :3] = StylisticTrans.hsv_disturb(img[..., :3], k.tolist(), b.tolist())
87
+ return img
88
+
89
+ return random_hsv_disturb
90
+
91
+ @staticmethod
92
+ def RandomRGBSwitch(p: float = 0.5):
93
+ def random_rgb_switch(img: np.ndarray):
94
+ if random.random() < p:
95
+ bgr_seq = list(range(3))
96
+ random.shuffle(bgr_seq)
97
+ img[..., :3] = StylisticTrans.bgr_switch(img[..., :3], bgr_seq)
98
+ return img
99
+
100
+ return random_rgb_switch
101
+
102
+ """
103
+ # ===========================================================================================================
104
+ # random spatial augmentations; these funcs can be applied to tiles and their masks.
105
+ """
106
+
107
+ @staticmethod
108
+ def RandomRotate90(p: float = 0.5):
109
+ def random_rotate90(img: np.ndarray):
110
+ if random.random() < p:
111
+ angle = 90 * random.randint(1, 3)
112
+ img = SpatialTrans.rotate(img, angle)
113
+ return img
114
+
115
+ return random_rotate90
116
+
117
+ @staticmethod
118
+ def RandomHorizontalFlip(p: float = 0.5):
119
+ def random_horizontal_flip(img: np.ndarray):
120
+ if random.random() < p:
121
+ img = SpatialTrans.flip(img, 0)
122
+ return img
123
+
124
+ return random_horizontal_flip
125
+
126
+ @staticmethod
127
+ def RandomVerticalFlip(p: float = 0.5):
128
+ def random_vertical_flip(img: np.ndarray):
129
+ if random.random() < p:
130
+ img = SpatialTrans.flip(img, 1)
131
+ return img
132
+
133
+ return random_vertical_flip
134
+
135
+ @staticmethod
136
+ def RandomScale(p: float = 0.5):
137
+ def random_scale(img: np.ndarray):
138
+ if random.random() < p:
139
+ ratio = 0.8 + random.random() * 0.4
140
+ img = SpatialTrans.scale(img, ratio, True)
141
+ return img
142
+
143
+ return random_scale
144
+
145
+ @staticmethod
146
+ def RandomCrop(p: float = 1., size: tuple = (512, 512)):
147
+ def random_crop(img: np.ndarray):
148
+ if random.random() < p:
149
+ # for a large FOV, control the translate range
150
+ new_shape = list(img.shape[:2][::-1])
151
+ if img.shape[0] > size[1] * 1.5:
152
+ new_shape[1] = int(size[1] * 1.5)
153
+ if img.shape[1] > size[0] * 1.5:
154
+ new_shape[0] = int(size[0] * 1.5)
155
+ img = SpatialTrans.center_crop(img.copy(), tuple(new_shape))
156
+ # do translate
157
+ xy = np.random.random(2) * (np.array(img.shape[:2]) - list(size))
158
+ bbox = tuple(xy.astype(int).tolist() + list(size))  # np.int was removed in NumPy 1.24+
159
+ img = SpatialTrans.crop(img, bbox)
160
+ else:
161
+ img = SpatialTrans.center_crop(img, size)
162
+ return img
163
+
164
+ return random_crop
165
+
166
+ @staticmethod
167
+ def Normalization(rng: list = [-1, 1]):
168
+ def norm(img: np.ndarray):
169
+ img = StylisticTrans.normalization(img, rng)
170
+ return img
171
+
172
+ return norm
173
+
174
+ @staticmethod
175
+ def CenterCrop(size: tuple = (512, 512)):
176
+ def center_crop(img: np.ndarray):
177
+ img = SpatialTrans.center_crop(img, size)
178
+ return img
179
+
180
+ return center_crop
181
+
182
+
183
+ class StylisticTrans:
184
+ # TODO: some augmentation implementations need a more efficient approach
185
+ """
186
+ set of augmentations applied to the content of image
187
+ """
188
+
189
+ @staticmethod
190
+ def gamma_adjust(img: np.ndarray, gamma: float):
191
+ """ adjust gamma
192
+ :param img: a ndarray, better a BGR
193
+ :param gamma: gamma, recommended value 0.6, range [0.6, 1.2]
194
+ :return: a ndarray
195
+ """
196
+ return adjust_gamma(img.copy(), gamma)
197
+
198
+ @staticmethod
199
+ def sharp(img: np.ndarray, sigma: float):
200
+ """sharp image
201
+ :param img: a ndarray, better a BGR
202
+ :param sigma: sharp degree, recommended range [8.3, 8.7]
203
+ :return: a ndarray
204
+ """
205
+ kernel = np.array([[-1, -1, -1], [-1, sigma, -1], [-1, -1, -1]], np.float32) / (sigma - 8)  # sharpening kernel
206
+ return cv2.filter2D(img.copy(), -1, kernel=kernel)
207
+
208
+ @staticmethod
209
+ def gaussian_blur(img: np.ndarray, sigma: float):
210
+ """blurring image
211
+ :param img: a ndarray, better a BGR
212
+ :param sigma: blurring degree, recommended range [0.1, 1.1]
213
+ :return: a ndarray
214
+ """
215
+ return cv2.GaussianBlur(img.copy(), (int(6 * np.ceil(sigma) + 1), int(6 * np.ceil(sigma) + 1)), sigma)
216
+
217
+ @staticmethod
218
+ def hsv_disturb(img: np.ndarray, k: list, b: list):
219
+ """ disturb the hsv value
220
+ :param img: a BGR ndarray
221
+ :param k: low_b = [0.95, 0.7, 0.75] ,upper_b = [1.05, 1.5, 1.2]
222
+ :param b: low_b = [-3, -10, -10] ,upper_b = [3, 10, 8]
223
+ :return: a BGR ndarray
224
+ """
225
+ img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV)
226
+ img = img.astype(np.float64)  # np.float was removed in NumPy 1.24+
227
+ for ch in range(3):
228
+ img[..., ch] = k[ch] * img[..., ch] + b[ch]
229
+ img = np.uint8(np.clip(img, np.array([0, 1, 1]), np.array([180, 255, 255])))
230
+ return cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
231
+
232
+ @staticmethod
233
+ def bgr_switch(img: np.ndarray, bgr_seq: list):
234
+ """ switch bgr
235
+ :param img: a ndarray, better a BGR
236
+ :param bgr_seq: new ch seq
237
+ :return: a ndarray
238
+ """
239
+ return img.copy()[..., bgr_seq]
240
+
241
+ @staticmethod
242
+ def normalization(img: np.ndarray, rng: list):
243
+ """normalize image according to min and max
244
+ :param img: a ndarray
245
+ :param rng: normalize image value to range[min, max]
246
+ :return: a ndarray
247
+ """
248
+ lb, ub = rng
249
+ delta = ub - lb
250
+ return (img.copy().astype(np.float64) / 255.) * delta + lb
251
+
252
+
253
+ class SpatialTrans:
254
+ """
255
+ set of augmentations applied to the spatial space of image
256
+ """
257
+
258
+ @staticmethod
259
+ def rotate(img: np.ndarray, angle: int):
260
+ """ rotate image
261
+ # todo Square image and central rotate only, a universal version is needed
262
+ :param img: a ndarray
263
+ :param angle: rotate angle
264
+ :return: a ndarray has same size as input, padding zero or cut out region out of picture
265
+ """
266
+ assert img.shape[0] == img.shape[1], "Square image needed."
267
+ center = (img.shape[0]/2, img.shape[1]/2)
268
+ mat = cv2.getRotationMatrix2D(center, angle, scale=1)
269
+ # mat = cv2.getRotationMatrix2D(tuple(np.array(img.shape[:2]) // 2), angle, scale=1)
270
+ return cv2.warpAffine(img.copy(), mat, img.shape[:2])
271
+
272
+ @staticmethod
273
+ def flip(img: np.ndarray, flip_axis: int):
274
+ """flip image horizontal or vertical
275
+ :param img: a ndarray
276
+ :param flip_axis: 0 for horizontal, 1 for vertical
277
+ :return: a flipped image
278
+ """
279
+ return cv2.flip(img.copy(), flip_axis)
280
+
281
+ @staticmethod
282
+ def scale(img: np.ndarray, ratio: float, fix_size: bool = False):
283
+ """scale image
284
+ :param img: a ndarray
285
+ :param ratio: scale ratio
286
+ :param fix_size: return the center area of scaled image, size of area is same as the image before scaling
287
+ :return: a scaled image
288
+ """
289
+ shape = img.shape[:2][::-1]
290
+ img = cv2.resize(img.copy(), None, fx=ratio, fy=ratio)
291
+ if fix_size:
292
+ img = SpatialTrans.center_crop(img, shape)
293
+ return img
294
+
295
+ @staticmethod
296
+ def crop(img: np.ndarray, bbox: tuple):
297
+ """crop image according to given bbox
298
+ :param img: a ndarray
299
+ :param bbox: bbox of cropping area (x, y, w, h)
300
+ :return: cropped image,padding with zeros
301
+ """
302
+ ch = [] if len(img.shape) == 2 else [img.shape[-1]]
303
+ template = np.zeros(list(bbox[-2:])[::-1] + ch)
304
+
305
+ if (bbox[1] >= img.shape[0] or bbox[1] >= img.shape[1]) or (bbox[0] + bbox[2] <= 0 or bbox[1] + bbox[3] <= 0):
306
+ logger.warning("Crop area contains nothing, return a zeros array {}".format(template.shape))
307
+ return template
308
+
309
+ foreground = img[
310
+ np.maximum(bbox[1], 0): np.minimum(bbox[1] + bbox[3], img.shape[0]),
311
+ np.maximum(bbox[0], 0): np.minimum(bbox[0] + bbox[2], img.shape[1]), :]
312
+
313
+ template[
314
+ np.maximum(-bbox[1], 0): np.minimum(-bbox[1] + img.shape[0], bbox[3]),
315
+ np.maximum(-bbox[0], 0): np.minimum(-bbox[0] + img.shape[1], bbox[2]), :] = foreground
316
+ return template.astype(np.uint8)
317
+
318
+ @staticmethod
319
+ def center_crop(img: np.ndarray, shape: tuple):
320
+ """return the center area in shape
321
+ :param img: a ndarray
322
+ :param shape: center crop shape (w, h)
323
+ :return:
324
+ """
325
+ center = np.array(img.shape[:2]) // 2
326
+ init = center[::-1] - np.array(shape) // 2
327
+ bbox = tuple(init.tolist() + list(shape))
328
+ return SpatialTrans.crop(img, bbox)
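A short usage sketch for the Augmentations helpers above; the input tile is random placeholder data, and the (0, 1) normalisation range mirrors the one used in app.py's preprocess_for_mwt.

```python
import numpy as np
from augmentations import Augmentations

# Compose a simple pipeline; Compose keeps the normalisation step last.
pipeline = Augmentations.Compose(
    Augmentations.RandomRotate90(p=0.5),
    Augmentations.RandomHorizontalFlip(p=0.5),
    Augmentations.Normalization((0, 1)),
)

dummy_tile = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # placeholder image
augmented = pipeline(dummy_tile)
print(augmented.shape, float(augmented.min()), float(augmented.max()))
```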
backend/best2.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dc58852bdf5287554846eacd4d27daf757dbbcb111cec21e7df2bc463401e5e
3
+ size 6236202
backend/histopathology_trained_model.keras ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62e427f5d52567b488b157fa1aa8e3ef8236434b1b3752ca49d8a46c144d90c1
3
+ size 8069694
backend/logistic_regression_model.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09c3442ef1cec5614ed17bbd9e1a4a1f8e38c588fdea13b2171b6662ace86c7b
3
+ size 94559
backend/model.py ADDED
@@ -0,0 +1,521 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint as checkpoint
5
+ import numpy as np
6
+ from typing import Optional
7
+ from thop import profile
8
+
9
+ class IRB(nn.Module):
10
+ def __init__(self, in_features, hidden_features=None, out_features=None, ksize=3, act_layer=nn.Hardswish, drop=0.):
11
+ super().__init__()
12
+ out_features = out_features or in_features
13
+ hidden_features = hidden_features or in_features
14
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0)
15
+ self.act = act_layer()
16
+ self.conv = nn.Conv2d(hidden_features, hidden_features, kernel_size=ksize, padding=ksize // 2, stride=1,
17
+ groups=hidden_features)
18
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0)
19
+ self.drop = nn.Dropout(drop)
20
+
21
+ def forward(self, x, H, W):
22
+ B, N, C = x.shape
23
+ x = x.permute(0, 2, 1).reshape(B, C, H, W)
24
+ x = self.fc1(x)
25
+ x = self.act(x)
26
+ x = self.conv(x)
27
+ x = self.act(x)
28
+ x = self.fc2(x)
29
+ return x.reshape(B, C, -1).permute(0, 2, 1)
30
+
31
+
32
+ def drop_path_f(x, drop_prob: float = 0., training: bool = False):
33
+
34
+ if drop_prob == 0. or not training:
35
+ return x
36
+ keep_prob = 1 - drop_prob
37
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
38
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
39
+ random_tensor.floor_() # binarize
40
+ output = x.div(keep_prob) * random_tensor
41
+ return output
42
+
43
+
44
+ class DropPath(nn.Module):
45
+
46
+ def __init__(self, drop_prob=None):
47
+ super(DropPath, self).__init__()
48
+ self.drop_prob = drop_prob
49
+
50
+ def forward(self, x):
51
+ return drop_path_f(x, self.drop_prob, self.training)
52
+
53
+
54
+ def window_partition(x, window_size: int):
55
+ B, H, W, C = x.shape
56
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
57
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
58
+ return windows
59
+
60
+
61
+ def window_reverse(windows, window_size: int, H: int, W: int):
62
+
63
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
64
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
65
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
66
+ return x
67
+
68
+
69
+ class PatchEmbed2(nn.Module):
70
+ def __init__(self, dim:int, patch_size=2, in_c=3, norm_layer=None):
71
+ super().__init__()
72
+ patch_size = (patch_size, patch_size)
73
+ self.patch_size = patch_size
74
+ self.in_chans = in_c
75
+ self.embed_dim = dim
76
+ self.proj = nn.Conv2d(dim, 2*dim, kernel_size=patch_size, stride=patch_size)
77
+ self.norm = norm_layer(2*dim) if norm_layer else nn.Identity()
78
+
79
+ def forward(self, x, H, W):
80
+ B, L, C = x.shape
81
+ assert L == H * W, "input feature has wrong size"
82
+
83
+ x = x.view(B, H, W, C)
84
+ pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
85
+ if pad_input:
86
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
87
+ 0, self.patch_size[0] - H % self.patch_size[0],
88
+ 0, 0))
89
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2)
90
+ x = self.proj(x)
91
+ _, _, H, W = x.shape
92
+ x = x.flatten(2).transpose(1, 2)
93
+ x = self.norm(x)
94
+ return x
95
+
96
+
97
+ class PatchEmbed(nn.Module):
98
+ def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
99
+ super().__init__()
100
+ patch_size = (patch_size, patch_size)
101
+ self.patch_size = patch_size
102
+ self.in_chans = in_c
103
+ self.embed_dim = embed_dim
104
+ self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
105
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
106
+
107
+ def forward(self, x):
108
+ _, _, H, W = x.shape
109
+
110
+ # padding
111
+ # If the input H, W are not integer multiples of patch_size, pad the input
112
+ pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
113
+ if pad_input:
114
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
115
+ 0, self.patch_size[0] - H % self.patch_size[0],
116
+ 0, 0))
117
+
118
+ # Downsample by a factor of patch_size
119
+ x = self.proj(x)
120
+ _, _, H, W = x.shape
121
+ x = x.flatten(2).transpose(1, 2)
122
+ x = self.norm(x)
123
+ return x, H, W
124
+
125
+
126
+ class PatchMerging(nn.Module):
127
+
128
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
129
+ super().__init__()
130
+ self.dim = dim
131
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
132
+ self.norm = norm_layer(4 * dim)
133
+
134
+ def forward(self, x, H, W):
135
+ B, L, C = x.shape
136
+ assert L == H * W, "input feature has wrong size"
137
+
138
+ x = x.view(B, H, W, C)
139
+
140
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
141
+ if pad_input:
142
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
143
+
144
+ x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
145
+ x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
146
+ x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
147
+ x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
148
+ x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
149
+ x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
150
+
151
+ x = self.norm(x)
152
+ x = self.reduction(x) # [B, H/2*W/2, 2*C]
153
+
154
+ return x
155
+
156
+
157
+ class Mlp(nn.Module):
158
+
159
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
160
+ super().__init__()
161
+ out_features = out_features or in_features
162
+ hidden_features = hidden_features or in_features
163
+
164
+ self.fc1 = nn.Linear(in_features, hidden_features)
165
+ self.act = act_layer()
166
+ self.drop1 = nn.Dropout(drop)
167
+ self.fc2 = nn.Linear(hidden_features, out_features)
168
+ self.drop2 = nn.Dropout(drop)
169
+
170
+ def forward(self, x):
171
+ x = self.fc1(x)
172
+ x = self.act(x)
173
+ x = self.drop1(x)
174
+ x = self.fc2(x)
175
+ x = self.drop2(x)
176
+ return x
177
+
178
+
179
+ class WindowAttention(nn.Module):
180
+
181
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
182
+
183
+ super().__init__()
184
+ self.dim = dim
185
+ self.window_size = window_size # [Mh, Mw]
186
+ self.num_heads = num_heads
187
+ head_dim = dim // num_heads
188
+ self.scale = head_dim ** -0.5
189
+
190
+ # define a parameter table of relative position bias
191
+ self.relative_position_bias_table = nn.Parameter(
192
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # [2*Mh-1 * 2*Mw-1, nH]
193
+
194
+ # get pair-wise relative position index for each token inside the window
195
+ coords_h = torch.arange(self.window_size[0])
196
+ coords_w = torch.arange(self.window_size[1])
197
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
198
+ coords_flatten = torch.flatten(coords, 1) # [2, Mh*Mw]
199
+ # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
200
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # [2, Mh*Mw, Mh*Mw]
201
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # [Mh*Mw, Mh*Mw, 2]
202
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
203
+ relative_coords[:, :, 1] += self.window_size[1] - 1
204
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
205
+ relative_position_index = relative_coords.sum(-1) # [Mh*Mw, Mh*Mw]
206
+ self.register_buffer("relative_position_index", relative_position_index)
207
+
208
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
209
+ self.attn_drop = nn.Dropout(attn_drop)
210
+ self.proj = nn.Linear(dim, dim)
211
+ self.proj_drop = nn.Dropout(proj_drop)
212
+
213
+ nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
214
+ self.softmax = nn.Softmax(dim=-1)
215
+
216
+ def forward(self, x, mask: Optional[torch.Tensor] = None):
217
+ B_, N, C = x.shape
218
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
219
+ # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
220
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
221
+
222
+ # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
223
+ # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
224
+ q = q * self.scale
225
+ attn = (q @ k.transpose(-2, -1))
226
+
227
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
228
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
229
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # [nH, Mh*Mw, Mh*Mw]
230
+ attn = attn + relative_position_bias.unsqueeze(0)
231
+
232
+ if mask is not None:
233
+ nW = mask.shape[0] # num_windows
234
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
235
+ attn = attn.view(-1, self.num_heads, N, N)
236
+ attn = self.softmax(attn)
237
+ else:
238
+ attn = self.softmax(attn)
239
+
240
+ attn = self.attn_drop(attn)
241
+
242
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
243
+ x = self.proj(x)
244
+ x = self.proj_drop(x)
245
+ return x
246
+
247
+
248
+ class TransformerBlock(nn.Module):
249
+ def __init__(self, dim, num_heads, window_sizes=(7,4,2), branch_num=3,
250
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
251
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
252
+ super().__init__()
253
+ self.dim = dim
254
+ self.num_heads = num_heads
255
+ self.window_sizes = window_sizes
256
+ self.branch_num = branch_num
257
+ self.mlp_ratio = mlp_ratio
258
+
259
+ self.norm1 = norm_layer(dim)
260
+ self.attn = WindowAttention(
261
+ dim//branch_num, window_size=(self.window_sizes[0], self.window_sizes[0]), num_heads=num_heads//branch_num, qkv_bias=qkv_bias,
262
+ attn_drop=attn_drop, proj_drop=drop)
263
+ self.attn1 = WindowAttention(
264
+ dim//branch_num, window_size=(self.window_sizes[1], self.window_sizes[1]), num_heads=num_heads//branch_num, qkv_bias=qkv_bias,
265
+ attn_drop=attn_drop, proj_drop=drop)
266
+ self.attn2 = WindowAttention(
267
+ dim//branch_num, window_size=(self.window_sizes[2], self.window_sizes[2]), num_heads=num_heads//branch_num, qkv_bias=qkv_bias,
268
+ attn_drop=attn_drop, proj_drop=drop)
269
+
270
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
271
+ self.norm2 = norm_layer(dim)
272
+ mlp_hidden_dim = int(dim * mlp_ratio)
273
+ self.mlp = IRB(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
274
+
275
+ def forward(self, x, attn_mask):
276
+ H, W = self.H, self.W
277
+ B, L, C = x.shape
278
+ assert L == H * W, "input feature has wrong size"
279
+
280
+ shortcut = x
281
+ x = self.norm1(x)
282
+ x = x.view(B, H, W, C)
283
+ x0 = x[:,:,:,:(C//self.branch_num)]
284
+ x1 = x[:,:,:,(C//self.branch_num):(2*C//self.branch_num)]
285
+ x2 = x[:,:,:,(2*C//self.branch_num):]
286
+ # ----------------------------------------------------------------------------------------------
287
+ pad_l = pad_t = 0
288
+ pad_r = (self.window_sizes[0] - W % self.window_sizes[0]) % self.window_sizes[0]
289
+ pad_b = (self.window_sizes[0] - H % self.window_sizes[0]) % self.window_sizes[0]
290
+ x0 = F.pad(x0, (0, 0, pad_l, pad_r, pad_t, pad_b))
291
+ _, Hp, Wp, _ = x0.shape
292
+ attn_mask = None
293
+
294
+ # partition windows
295
+ x_windows = window_partition(x0, self.window_sizes[0]) # [nW*B, Mh, Mw, C]
296
+ x_windows = x_windows.view(-1, self.window_sizes[0] * self.window_sizes[0], C//self.branch_num) # [nW*B, Mh*Mw, C]
297
+
298
+ # W-MSA/SW-MSA
299
+ attn_windows = self.attn(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
300
+
301
+ # merge windows
302
+ attn_windows = attn_windows.view(-1, self.window_sizes[0], self.window_sizes[0], C//self.branch_num) # [nW*B, Mh, Mw, C]
303
+ x0 = window_reverse(attn_windows, self.window_sizes[0], Hp, Wp) # [B, H', W', C]
304
+
305
+ if pad_r > 0 or pad_b > 0:
306
+ # Remove the padding added above
307
+ x0 = x0[:, :H, :W, :].contiguous()
308
+
309
+ x0 = x0.view(B, H * W, C//self.branch_num)
310
+ # ----------------------------------------------------------------------------------------------
311
+ pad_l = pad_t = 0
312
+ pad_r = (self.window_sizes[1] - W % self.window_sizes[1]) % self.window_sizes[1]
313
+ pad_b = (self.window_sizes[1] - H % self.window_sizes[1]) % self.window_sizes[1]
314
+ x1 = F.pad(x1, (0, 0, pad_l, pad_r, pad_t, pad_b))
315
+ _, Hp, Wp, _ = x1.shape
316
+ attn_mask = None
317
+
318
+ # partition windows
319
+ x_windows = window_partition(x1, self.window_sizes[1]) # [nW*B, Mh, Mw, C]
320
+ x_windows = x_windows.view(-1, self.window_sizes[1] * self.window_sizes[1],
321
+ C // self.branch_num) # [nW*B, Mh*Mw, C]
322
+
323
+ # W-MSA/SW-MSA
324
+ attn_windows = self.attn1(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
325
+
326
+ # merge windows
327
+ attn_windows = attn_windows.view(-1, self.window_sizes[1], self.window_sizes[1],
328
+ C // self.branch_num) # [nW*B, Mh, Mw, C]
329
+ x1 = window_reverse(attn_windows, self.window_sizes[1], Hp, Wp) # [B, H', W', C]
330
+
331
+ if pad_r > 0 or pad_b > 0:
332
+ # Remove the padding added above
333
+ x1 = x1[:, :H, :W, :].contiguous()
334
+ x1 = x1.view(B, H * W, C // self.branch_num)
335
+ # ----------------------------------------------------------------------------------------------
336
+ pad_l = pad_t = 0
337
+ pad_r = (self.window_sizes[2] - W % self.window_sizes[2]) % self.window_sizes[2]
338
+ pad_b = (self.window_sizes[2] - H % self.window_sizes[2]) % self.window_sizes[2]
339
+ x2 = F.pad(x2, (0, 0, pad_l, pad_r, pad_t, pad_b))
340
+ _, Hp, Wp, _ = x2.shape
341
+ attn_mask = None
342
+ x_windows = window_partition(x2, self.window_sizes[2]) # [nW*B, Mh, Mw, C]
343
+ x_windows = x_windows.view(-1, self.window_sizes[2] * self.window_sizes[2],
344
+ C // self.branch_num) # [nW*B, Mh*Mw, C]
345
+
346
+ attn_windows = self.attn2(x_windows, mask=attn_mask) # [nW*B, Mh*Mw, C]
347
+
348
+ attn_windows = attn_windows.view(-1, self.window_sizes[2], self.window_sizes[2],
349
+ C // self.branch_num) # [nW*B, Mh, Mw, C]
350
+ x2 = window_reverse(attn_windows, self.window_sizes[2], Hp, Wp) # [B, H', W', C]
351
+
352
+ if pad_r > 0 or pad_b > 0:
353
+ x2 = x2[:, :H, :W, :].contiguous()
354
+
355
+ x2 = x2.view(B, H * W, C // self.branch_num)
356
+
357
+ x = torch.cat([x0, x1, x2], -1)
358
+ # FFN
359
+ x = shortcut + self.drop_path(x)
360
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
361
+
362
+ return x
363
+
364
+
365
+ class BasicLayer(nn.Module):
366
+ def __init__(self, dim, depth, num_heads, window_size,
367
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
368
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
369
+ super().__init__()
370
+ self.dim = dim
371
+ self.depth = depth
372
+ self.window_size = window_size
373
+ self.use_checkpoint = use_checkpoint
374
+ self.shift_size = window_size // 2
375
+
376
+ # build blocks
377
+ self.blocks = nn.ModuleList([
378
+ TransformerBlock(
379
+ dim=dim,
380
+ num_heads=num_heads,
381
+ window_sizes=(7,4,2),
382
+ mlp_ratio=mlp_ratio,
383
+ qkv_bias=qkv_bias,
384
+ drop=drop,
385
+ attn_drop=attn_drop,
386
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
387
+ norm_layer=norm_layer)
388
+ for i in range(depth)])
389
+
390
+ # patch merging layer
391
+ if downsample is not None:
392
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
393
+ else:
394
+ self.downsample = None
395
+
396
+ def create_mask(self, x, H, W):
397
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
398
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
399
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # [1, Hp, Wp, 1]
400
+ h_slices = (slice(0, -self.window_size),
401
+ slice(-self.window_size, -self.shift_size),
402
+ slice(-self.shift_size, None))
403
+ w_slices = (slice(0, -self.window_size),
404
+ slice(-self.window_size, -self.shift_size),
405
+ slice(-self.shift_size, None))
406
+ cnt = 0
407
+ for h in h_slices:
408
+ for w in w_slices:
409
+ img_mask[:, h, w, :] = cnt
410
+ cnt += 1
411
+
412
+ mask_windows = window_partition(img_mask, self.window_size) # [nW, Mh, Mw, 1]
413
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # [nW, Mh*Mw]
414
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
415
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
416
+ return attn_mask
417
+
418
+ def forward(self, x, H, W):
419
+ attn_mask = self.create_mask(x, H, W) # [nW, Mh*Mw, Mh*Mw]
420
+ for blk in self.blocks:
421
+ blk.H, blk.W = H, W
422
+ if not torch.jit.is_scripting() and self.use_checkpoint:
423
+ x = checkpoint.checkpoint(blk, x, attn_mask)
424
+ else:
425
+ x = blk(x, attn_mask)
426
+ if self.downsample is not None:
427
+ x = self.downsample(x, H, W)
428
+ H, W = (H + 1) // 2, (W + 1) // 2
429
+
430
+ return x, H, W
431
+
432
+
433
+ class Transformer(nn.Module):
434
+ def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
435
+ embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
436
+ window_size=7, mlp_ratio=4., qkv_bias=True,
437
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
438
+ norm_layer=nn.LayerNorm, patch_norm=True,
439
+ use_checkpoint=False, **kwargs):
440
+ super().__init__()
441
+
442
+ self.num_classes = num_classes
443
+ self.num_layers = len(depths)
444
+ self.embed_dim = embed_dim
445
+ self.patch_norm = patch_norm
446
+ # Number of channels in the stage-4 output feature map
447
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
448
+ self.mlp_ratio = mlp_ratio
449
+
450
+ self.patch_embed = PatchEmbed(
451
+ patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
452
+ norm_layer=norm_layer if self.patch_norm else None)
453
+ self.pos_drop = nn.Dropout(p=drop_rate)
454
+
455
+ # stochastic depth
456
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
457
+
458
+ self.layers = nn.ModuleList()
459
+ for i_layer in range(self.num_layers):
460
+ layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
461
+ depth=depths[i_layer],
462
+ num_heads=num_heads[i_layer],
463
+ window_size=window_size,
464
+ mlp_ratio=self.mlp_ratio,
465
+ qkv_bias=qkv_bias,
466
+ drop=drop_rate,
467
+ attn_drop=attn_drop_rate,
468
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
469
+ norm_layer=norm_layer,
470
+ downsample=PatchEmbed2 if (i_layer < self.num_layers - 1) else None,
471
+ use_checkpoint=use_checkpoint)
472
+ self.layers.append(layers)
473
+
474
+ self.norm = norm_layer(self.num_features)
475
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
476
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
477
+
478
+ self.apply(self._init_weights)
479
+
480
+ def _init_weights(self, m):
481
+ if isinstance(m, nn.Linear):
482
+ nn.init.trunc_normal_(m.weight, std=.02)
483
+ if isinstance(m, nn.Linear) and m.bias is not None:
484
+ nn.init.constant_(m.bias, 0)
485
+ elif isinstance(m, nn.LayerNorm):
486
+ nn.init.constant_(m.bias, 0)
487
+ nn.init.constant_(m.weight, 1.0)
488
+
489
+ def forward(self, x):
490
+ # x: [B, L, C]
491
+ x, H, W = self.patch_embed(x)
492
+ x = self.pos_drop(x)
493
+
494
+ for layer in self.layers:
495
+ x, H, W = layer(x, H, W)
496
+
497
+ x = self.norm(x) # [B, L, C]
498
+ x = self.avgpool(x.transpose(1, 2)) # [B, C, 1]
499
+ x = torch.flatten(x, 1)
500
+ x = self.head(x)
501
+ return x
502
+
503
+
504
+ def MWT(num_classes: int = 1000, **kwargs):
505
+ model = Transformer(in_chans=3,
506
+ patch_size=4,
507
+ # window_sizes=(7,4,2),
508
+ embed_dim=96,
509
+ depths=(2, 4, 4, 2),
510
+ num_heads=(3, 6, 12, 24),
511
+ num_classes=num_classes,
512
+ **kwargs)
513
+ return model
514
+
515
+ if __name__ == '__main__':
516
+ model = MWT(num_classes=2)
517
+ input = torch.randn(1, 3, 224, 224)
518
+ flops, params = profile(model, inputs=(input,))
519
+ print(flops)
520
+ print(params)
521
+
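A brief inference sketch for the MWT classifier defined above, mirroring how app.py loads the committed checkpoint; the input tensor is random stand-in data rather than a real preprocessed colposcopy image.

```python
import torch
from model import MWT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the two-class model and load the checkpoint, as app.py does
mwt_model = MWT(num_classes=2).to(device)
mwt_model.load_state_dict(torch.load("MWTclass2.pth", map_location=device))
mwt_model.eval()

# Random stand-in for a preprocessed 224x224 RGB tile
dummy = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    probs = torch.softmax(mwt_model(dummy), dim=1)[0]
print({name: float(p) for name, p in zip(["neg", "pos"], probs)})
```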
backend/model_histo.py ADDED
@@ -0,0 +1,1495 @@
1
+ """
2
+ Breast Cancer Histopathology Classification using Path Foundation Model
3
+
4
+ This module implements a comprehensive deep learning pipeline for breast cancer classification
5
+ from histopathology images using Google's Path Foundation model as a feature extractor. The
6
+ system supports multiple datasets including BreakHis, PatchCamelyon (PCam), and BACH, employing
7
+ transfer learning to achieve high classification accuracy.
8
+
9
+ **Overview:**
10
+ This system leverages Google's Path Foundation model, which is pre-trained on a large corpus
11
+ of pathology images, to extract meaningful features from breast cancer histopathology images.
12
+ The approach uses transfer learning where the foundation model serves as a frozen feature
13
+ extractor, followed by a trainable classification head for binary classification (benign vs malignant).
14
+
15
+ **Model Architecture:**
16
+ - Foundation Model: Google's Path Foundation (pre-trained on pathology images)
17
+ - Transfer Learning Approach: Feature extraction with frozen foundation model + trainable classifier head
18
+ - Classification Head: Multi-layer dense network with regularisation and dropout
19
+ - Optimisation: AdamW optimiser with learning rate scheduling and early stopping
20
+
21
+ **Workflow:**
22
+ 1. Authentication & Model Loading: Authenticate with Hugging Face and load Path Foundation
23
+ 2. Data Loading: Load and preprocess histopathology datasets
24
+ 3. Feature Extraction: Extract embeddings using frozen foundation model
25
+ 4. Classifier Training: Train dense neural network on extracted features
26
+ 5. Evaluation: Comprehensive performance analysis with multiple metrics and visualisations
27
+
28
+ **Supported Datasets:**
29
+ - BreakHis: Breast cancer histopathology images at multiple magnifications
30
+ - PatchCamelyon (PCam): Lymph node metastasis detection patches
31
+ - BACH: ICIAR 2018 Breast Cancer Histology Challenge dataset
32
+ - Combined: Ensemble of all three datasets for robust training
33
+
34
+ **Key Features:**
35
+ - Multiple dataset support with consistent pre-processing
36
+ - Robust error handling and fallback mechanisms
37
+ - Comprehensive evaluation metrics and visualisation
38
+ - Memory-efficient batch processing
39
+ - Data augmentation capabilities
40
+ - Model persistence and deployment support
41
+
42
+ Author: Research Team
43
+ Date: 2024
44
+ License: MIT
45
+ """
46
+
47
+ # Import required libraries and configure environment
48
+ import os
49
+ import tensorflow as tf
50
+ import numpy as np
51
+ from PIL import Image
52
+ from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
53
+ from pathlib import Path
54
+ import h5py
55
+ from sklearn.model_selection import train_test_split
56
+ from sklearn.utils.class_weight import compute_class_weight
57
+ from tensorflow.keras import regularizers
58
+ import matplotlib
59
+ # Use a non-interactive backend to prevent blocking on plt.show()
60
+ matplotlib.use('Agg')
61
+ import matplotlib.pyplot as plt
62
+ import seaborn as sns
63
+
64
+ # Suppress TensorFlow logging for cleaner output
65
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
66
+
67
+ # Configure TensorFlow logging for cleaner output
68
+ try:
69
+ tf.get_logger().setLevel('ERROR')
70
+ except AttributeError:
71
+ import logging
72
+ logging.getLogger('tensorflow').setLevel(logging.ERROR)
73
+
74
+ # Configure Hugging Face Hub integration with fallback mechanisms
75
+ # This section handles the loading of Google's Path Foundation model from Hugging Face Hub
76
+ # with multiple fallback methods to ensure compatibility across different environments
77
+ try:
78
+ from huggingface_hub import login, hf_hub_download, snapshot_download
79
+
80
+ # Try different methods for loading Keras models from HF Hub
81
+ # Method 1: Direct Keras loading (preferred)
82
+ try:
83
+ from huggingface_hub import from_pretrained_keras
84
+ KERAS_METHOD = "from_pretrained_keras"
85
+ except ImportError:
86
+ # Method 2: Transformers library fallback
87
+ try:
88
+ from transformers import TFAutoModel
89
+ KERAS_METHOD = "transformers"
90
+ except ImportError:
91
+ # Method 3: Manual download and TFSMLayer
92
+ KERAS_METHOD = "manual"
93
+
94
+ HF_AVAILABLE = True
95
+ print(f"Hugging Face Hub loaded successfully (method: {KERAS_METHOD})")
96
+ except ImportError as e:
97
+ print(f"Hugging Face Hub unavailable: {e}")
98
+ print("Please install required packages: pip install huggingface_hub transformers")
99
+ HF_AVAILABLE = False
100
+ KERAS_METHOD = None
101
+
102
+ class BreastCancerClassifier:
103
+ """
104
+ A comprehensive breast cancer classification system using Path Foundation model.
105
+
106
+ This class implements a transfer learning approach where Google's Path Foundation
107
+ model serves as a feature extractor, followed by a trainable classification head.
108
+ The system supports both feature extraction (frozen foundation model) and
109
+ fine-tuning approaches for maximum flexibility.
110
+
111
+ The classifier can work with multiple histopathology datasets and provides
112
+ comprehensive evaluation capabilities including confusion matrices, classification
113
+ reports, and performance metrics.
114
+
115
+ Attributes:
116
+ fine_tune (bool): Whether to fine-tune the foundation model or use it frozen
117
+ model (tf.keras.Model): The complete classification model
118
+ path_foundation: The loaded Path Foundation model from Hugging Face Hub
119
+ history: Training history from model.fit() containing loss and accuracy curves
120
+ embedding_dim (int): Dimensionality of extracted embeddings from foundation model
121
+ num_classes (int): Number of output classes (default: 2 for binary classification)
122
+
123
+ Example:
124
+ >>> classifier = BreastCancerClassifier(fine_tune=False)
125
+ >>> classifier.authenticate_huggingface()
126
+ >>> classifier.load_path_foundation()
127
+ >>> # Load data and train...
128
+ """
129
+
130
+ def __init__(self, fine_tune=False):
131
+ """
132
+ Initialise the breast cancer classifier.
133
+
134
+ Args:
135
+ fine_tune (bool): If True, allows fine-tuning of foundation model.
136
+ If False, uses foundation model as frozen feature extractor.
137
+
138
+ Note: Fine-tuning requires more computational resources and
139
+ may lead to overfitting on smaller datasets. Feature extraction
140
+ (fine_tune=False) is recommended for most use cases.
141
+ """
142
+ self.fine_tune = fine_tune
143
+ self.model = None
144
+ self.path_foundation = None
145
+ self.history = None
146
+ self.embedding_dim = None
147
+ self.num_classes = 2 # Binary classification: benign vs malignant
148
+
149
+ def authenticate_huggingface(self, token=None):
150
+ """
151
+ Authenticate with Hugging Face Hub to access Path Foundation model.
152
+
153
+ This method handles authentication with Hugging Face Hub, which is required
154
+ to download and use Google's Path Foundation model. It supports multiple
155
+ token sources and provides fallback mechanisms.
156
+
157
+ Args:
158
+ token (str, optional): Hugging Face access token. If None, the method
159
+ will attempt to use environment variables:
160
+ - HF_TOKEN
161
+ - HUGGINGFACE_HUB_TOKEN
162
+
163
+ Returns:
164
+ bool: True if authentication successful, False otherwise
165
+
166
+ Note:
167
+ You can obtain a Hugging Face token by:
168
+ 1. Creating an account at https://huggingface.co
169
+ 2. Going to Settings > Access Tokens
170
+ 3. Creating a new token with read permissions
171
+
172
+ Example:
173
+ >>> classifier = BreastCancerClassifier()
174
+ >>> success = classifier.authenticate_huggingface("hf_xxxxxxxxxxxx")
175
+ >>> if success:
176
+ ... print("Authentication successful")
177
+ """
178
+ if not HF_AVAILABLE:
179
+ print("Cannot authenticate - Hugging Face Hub not available")
180
+ return False
181
+
182
+ try:
183
+ # Try multiple token sources: parameter, environment variables
184
+ final_token = token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
185
+
186
+ if final_token:
187
+ login(token=final_token, add_to_git_credential=False)
188
+ print("Hugging Face authentication successful")
189
+ return True
190
+ else:
191
+ print("No token provided, attempting to use cached login")
192
+ return True
193
+ except Exception as e:
194
+ print(f"Authentication failed: {e}")
195
+ return False
196
+
197
+ def load_path_foundation(self):
198
+ """
199
+ Load Google's Path Foundation model with multiple fallback mechanisms.
200
+
201
+ This method attempts to load the Path Foundation model using three different
202
+ approaches to ensure maximum compatibility across different environments:
203
+
204
+ 1. Direct Keras loading via huggingface_hub (preferred)
205
+ 2. Transformers library loading (fallback)
206
+ 3. Manual download and TFSMLayer loading (last resort)
207
+
208
+ The method also configures the model's training behavior based on the
209
+ fine_tune parameter set during initialization.
210
+
211
+ Returns:
212
+ bool: True if model loaded successfully, False otherwise
213
+
214
+ Raises:
215
+ Various exceptions may be raised during the loading process, but they
216
+ are caught and handled gracefully with informative error messages.
217
+
218
+ Note:
219
+ The Path Foundation model is a large pre-trained model (~1GB) that will
220
+ be downloaded on first use. Subsequent runs will use the cached version.
221
+
222
+ Example:
223
+ >>> classifier = BreastCancerClassifier(fine_tune=False)
224
+ >>> if classifier.load_path_foundation():
225
+ ... print("Model loaded successfully")
226
+ ... else:
227
+ ... print("Failed to load model")
228
+ """
229
+ if not HF_AVAILABLE:
230
+ print("Cannot load model - Hugging Face Hub unavailable")
231
+ return False
232
+
233
+ try:
234
+ print("Loading Path Foundation model...")
235
+ loaded = False
236
+
237
+ # Method 1: Direct Keras loading (preferred method)
238
+ if KERAS_METHOD == "from_pretrained_keras":
239
+ try:
240
+ self.path_foundation = from_pretrained_keras("google/path-foundation")
241
+ loaded = True
242
+ print("Successfully loaded via from_pretrained_keras")
243
+ except Exception as e:
244
+ print(f"Keras loading failed: {e}")
245
+
246
+ # Method 2: Transformers library fallback
247
+ if not loaded and KERAS_METHOD == "transformers":
248
+ try:
249
+ print("Attempting transformers fallback...")
250
+ self.path_foundation = TFAutoModel.from_pretrained("google/path-foundation")
251
+ loaded = True
252
+ print("Successfully loaded via transformers")
253
+ except Exception as e:
254
+ print(f"Transformers loading failed: {e}")
255
+
256
+ # Method 3: Manual download and TFSMLayer (last resort)
257
+ if not loaded:
258
+ try:
259
+ try:
260
+ import keras as _standalone_keras
261
+ except ImportError as _e:
262
+ print(f"Keras 3 not installed: {_e}")
263
+ return False
264
+
265
+ print("Attempting manual download and TFSMLayer loading...")
266
+ local_dir = snapshot_download(repo_id="google/path-foundation")
267
+ self.path_foundation = _standalone_keras.layers.TFSMLayer(
268
+ local_dir, call_endpoint="serving_default"
269
+ )
270
+ loaded = True
271
+ print("Successfully loaded via TFSMLayer")
272
+ except Exception as e:
273
+ print(f"TFSMLayer loading failed: {e}")
274
+ return False
275
+
276
+ # Configure training behavior based on fine_tune setting
277
+ if self.fine_tune:
278
+ self.path_foundation.trainable = True
279
+ try:
280
+ # Only fine-tune the last 3 layers for stability
281
+ for layer in self.path_foundation.layers[:-3]:
282
+ layer.trainable = False
283
+ print("Fine-tuning enabled: last 3 layers trainable")
284
+ except Exception:
285
+ print("Fine-tuning enabled: full model trainable")
286
+ else:
287
+ self.path_foundation.trainable = False
288
+ print("Model frozen for feature extraction")
289
+
290
+ return True
291
+
292
+ except Exception as e:
293
+ print(f"Failed to load Path Foundation model: {e}")
294
+ return False
295
+
296
+ def preprocess_image_batch(self, images):
297
+ """
298
+ Pre-process a batch of images for Path Foundation model input.
299
+
300
+ This method handles multiple input formats and ensures all images are properly
301
+ formatted for the Path Foundation model. It performs the following operations:
302
+ - Resizes all images to 224x224 pixels (required by Path Foundation)
303
+ - Converts images to RGB format
304
+ - Normalises pixel values to [0, 1] range
305
+ - Handles both file paths and numpy arrays
306
+
307
+ Args:
308
+ images: List or array of images in various formats:
309
+ - File paths (strings) pointing to image files
310
+ - PIL Images
311
+ - NumPy arrays (various shapes and value ranges)
312
+
313
+ Returns:
314
+ np.ndarray: Preprocessed batch of shape (batch_size, 224, 224, 3)
315
+ with pixel values normalized to [0, 1] range
316
+
317
+ Note:
318
+ The method automatically handles different input formats and value ranges.
319
+ Images are resized using PIL's resize method with default interpolation.
320
+
321
+ Example:
322
+ >>> # Process file paths
323
+ >>> image_paths = ['image1.jpg', 'image2.png']
324
+ >>> processed = classifier.preprocess_image_batch(image_paths)
325
+ >>> print(processed.shape) # (2, 224, 224, 3)
326
+
327
+ >>> # Process numpy arrays
328
+ >>> image_arrays = [np.random.rand(100, 100, 3) for _ in range(5)]
329
+ >>> processed = classifier.preprocess_image_batch(image_arrays)
330
+ >>> print(processed.shape) # (5, 224, 224, 3)
331
+ """
332
+ processed = []
333
+
334
+ for img in images:
335
+ if isinstance(img, str):
336
+ # Handle file paths
337
+ img = Image.open(img).convert('RGB')
338
+ img = img.resize((224, 224))
339
+ img_array = np.array(img) / 255.0
340
+ else:
341
+ # Handle numpy arrays
342
+ if img.shape[:2] != (224, 224):
343
+ # Resize if necessary
344
+ if img.max() <= 1:
345
+ img_pil = Image.fromarray((img * 255).astype('uint8'))
346
+ else:
347
+ img_pil = Image.fromarray(img.astype('uint8'))
348
+ img_pil = img_pil.resize((224, 224))
349
+ img_array = np.array(img_pil) / 255.0
350
+ else:
351
+ img_array = img.astype('float32')
352
+ if img_array.max() > 1:
353
+ img_array = img_array / 255.0
354
+
355
+ processed.append(img_array)
356
+
357
+ return np.array(processed)
358
+
359
+ def extract_embeddings(self, images, batch_size=16):
360
+ """
361
+ Extract feature embeddings from images using Path Foundation model.
362
+
363
+ This method processes images in batches to extract high-level feature representations
364
+ using the pre-trained Path Foundation model. The embeddings capture semantic information
365
+ about the histopathology images that can be used for classification.
366
+
367
+ The method handles different model interface types and provides progress tracking
368
+ for large datasets. It automatically determines the embedding dimension on first use.
369
+
370
+ Args:
371
+ images: Array of preprocessed images or list of image paths
372
+ batch_size (int): Number of images to process per batch. Smaller batches
373
+ use less memory but may be slower. Default: 16
374
+
375
+ Returns:
376
+ np.ndarray: Extracted embeddings of shape (num_images, embedding_dim)
377
+ where embedding_dim is determined by the Path Foundation model
378
+
379
+ Raises:
380
+ ValueError: If no embeddings are successfully extracted
381
+ RuntimeError: If the Path Foundation model is not loaded
382
+
383
+ Note:
384
+ The embedding dimension is automatically determined from the first successful
385
+ batch and stored in self.embedding_dim for use in classifier construction.
386
+
387
+ Example:
388
+ >>> # Extract embeddings from a dataset
389
+ >>> embeddings = classifier.extract_embeddings(images, batch_size=32)
390
+ >>> print(f"Extracted {embeddings.shape[0]} embeddings of dimension {embeddings.shape[1]}")
391
+
392
+ >>> # Process with smaller batch size for memory-constrained environments
393
+ >>> embeddings = classifier.extract_embeddings(images, batch_size=8)
394
+ """
395
+ print(f"Extracting embeddings from {len(images)} images...")
396
+
397
+ embeddings = []
398
+ num_batches = (len(images) + batch_size - 1) // batch_size
399
+
400
+ for i in range(0, len(images), batch_size):
401
+ batch = images[i:i + batch_size]
402
+ processed_batch = self.preprocess_image_batch(batch)
403
+
404
+ try:
405
+ # Handle different model interface types
406
+ if hasattr(self.path_foundation, 'signatures') and "serving_default" in self.path_foundation.signatures:
407
+ # TensorFlow SavedModel format
408
+ infer = self.path_foundation.signatures["serving_default"]
409
+ batch_embeddings = infer(tf.constant(processed_batch))
410
+ elif hasattr(self.path_foundation, 'predict'):
411
+ # Standard Keras model
412
+ batch_embeddings = self.path_foundation.predict(processed_batch, verbose=0)
413
+ else:
414
+ # Direct callable
415
+ batch_embeddings = self.path_foundation(processed_batch)
416
+
417
+ # Handle different output formats
418
+ if isinstance(batch_embeddings, dict):
419
+ key = list(batch_embeddings.keys())[0]
420
+ if hasattr(batch_embeddings[key], 'numpy'):
421
+ batch_embeddings = batch_embeddings[key].numpy()
422
+ else:
423
+ batch_embeddings = batch_embeddings[key]
424
+ elif hasattr(batch_embeddings, 'numpy'):
425
+ batch_embeddings = batch_embeddings.numpy()
426
+
427
+ embeddings.append(batch_embeddings)
428
+
429
+ # Progress reporting
430
+ batch_num = i // batch_size + 1
431
+ if batch_num % 10 == 0:
432
+ print(f"Processed batch {batch_num}/{num_batches}")
433
+
434
+ except Exception as e:
435
+ print(f"Error processing batch {batch_num}: {e}")
436
+ continue
437
+
438
+ if not embeddings:
439
+ raise ValueError("No embeddings extracted successfully")
440
+
441
+ final_embeddings = np.vstack(embeddings)
442
+
443
+ # Set embedding dimension for classifier head
444
+ if self.embedding_dim is None:
445
+ self.embedding_dim = final_embeddings.shape[1]
446
+ print(f"Embedding dimension: {self.embedding_dim}")
447
+
448
+ print(f"Final embeddings shape: {final_embeddings.shape}")
449
+ return final_embeddings
450
+
451
+ def build_classifier(self):
452
+ """
453
+ Build the classification head architecture.
454
+
455
+ This method constructs the neural network architecture for breast cancer classification.
456
+ It creates different architectures based on the fine_tune setting:
457
+
458
+ 1. End-to-end model (fine_tune=True): Input -> Path Foundation -> Classifier -> Output
459
+ 2. Feature-based model (fine_tune=False): Embeddings -> Classifier -> Output
460
+
461
+ The architecture includes:
462
+ - Progressive dimensionality reduction (768 -> 384 -> 192 -> 2)
463
+ - L2 regularisation for weight decay and overfitting prevention
464
+ - Batch normalisation for training stability and faster convergence
465
+ - Dropout layers for regularization
466
+ - AdamW optimizer with appropriate learning rates
467
+
468
+ Returns:
469
+ None: The model is stored in self.model and compiled
470
+
471
+ Raises:
472
+ ValueError: If embedding dimension is not set (run extract_embeddings first)
473
+
474
+ Note:
475
+ The method automatically selects appropriate learning rates:
476
+ - Lower learning rate (1e-5) for fine-tuning to preserve pre-trained features
477
+ - Higher learning rate (0.001) for training from scratch on embeddings
478
+
479
+ Architecture Details:
480
+ - Input: Either raw images (224x224x3) or embeddings (embedding_dim,)
481
+ - Hidden layers: 768 -> 384 -> 192 neurons with ReLU activation
482
+ - Output: 2 neurons with softmax activation (benign/malignant)
483
+ - Regularisation: L2 weight decay (1e-4), Dropout (0.5, 0.3, 0.2)
484
+ - Normalisation: Batch normalisation after each dense layer
485
+
486
+ Example:
487
+ >>> classifier = BreastCancerClassifier(fine_tune=False)
488
+ >>> classifier.load_path_foundation()
489
+ >>> embeddings = classifier.extract_embeddings(images)
490
+ >>> classifier.build_classifier()
491
+ >>> print(f"Model has {classifier.model.count_params():,} parameters")
492
+ """
493
+ if self.embedding_dim is None:
494
+ raise ValueError("Embedding dimension not set - run extract_embeddings first")
495
+
496
+ if self.fine_tune:
497
+ # End-to-end fine-tuning architecture
498
+ inputs = tf.keras.Input(shape=(224, 224, 3))
499
+ x = self.path_foundation(inputs)
500
+
501
+ # Classification head with regularization
502
+ x = tf.keras.layers.Dense(768, activation='relu',
503
+ kernel_regularizer=regularizers.l2(1e-4))(x)
504
+ x = tf.keras.layers.BatchNormalization()(x)
505
+ x = tf.keras.layers.Dropout(0.5)(x)
506
+
507
+ x = tf.keras.layers.Dense(384, activation='relu',
508
+ kernel_regularizer=regularizers.l2(1e-4))(x)
509
+ x = tf.keras.layers.BatchNormalization()(x)
510
+ x = tf.keras.layers.Dropout(0.3)(x)
511
+
512
+ x = tf.keras.layers.Dense(192, activation='relu',
513
+ kernel_regularizer=regularizers.l2(1e-4))(x)
514
+ x = tf.keras.layers.Dropout(0.2)(x)
515
+
516
+ outputs = tf.keras.layers.Dense(self.num_classes, activation='softmax')(x)
517
+ self.model = tf.keras.Model(inputs, outputs)
518
+
519
+ # Lower learning rate for fine-tuning to preserve pre-trained features
520
+ optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-5, weight_decay=1e-5)
521
+
522
+ else:
523
+ # Feature extraction architecture (recommended approach)
524
+ self.model = tf.keras.Sequential([
525
+ tf.keras.layers.Input(shape=(self.embedding_dim,)),
526
+
527
+ # First dense block
528
+ tf.keras.layers.Dense(768, activation='relu',
529
+ kernel_regularizer=regularizers.l2(1e-4)),
530
+ tf.keras.layers.BatchNormalization(),
531
+ tf.keras.layers.Dropout(0.5),
532
+
533
+ # Second dense block
534
+ tf.keras.layers.Dense(384, activation='relu',
535
+ kernel_regularizer=regularizers.l2(1e-4)),
536
+ tf.keras.layers.BatchNormalization(),
537
+ tf.keras.layers.Dropout(0.3),
538
+
539
+ # Third dense block
540
+ tf.keras.layers.Dense(192, activation='relu',
541
+ kernel_regularizer=regularizers.l2(1e-4)),
542
+ tf.keras.layers.Dropout(0.2),
543
+
544
+ # Output layer
545
+ tf.keras.layers.Dense(self.num_classes, activation='softmax')
546
+ ])
547
+
548
+ # Higher learning rate for training from scratch
549
+ optimizer = tf.keras.optimizers.AdamW(learning_rate=0.001, weight_decay=1e-5)
550
+
551
+ # Compile model with sparse categorical crossentropy for integer labels
552
+ self.model.compile(
553
+ optimizer=optimizer,
554
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(),
555
+ metrics=['accuracy']
556
+ )
557
+
558
+ print(f"Model architecture built - Fine-tuning: {self.fine_tune}")
559
+ print(f"Total parameters: {self.model.count_params():,}")
560
+
561
+ def train_model(self, X_train, y_train, X_val, y_val, epochs=50):
562
+ """
563
+ Train the classification model with advanced techniques and callbacks.
564
+
565
+ This method implements a comprehensive training pipeline with:
566
+ - Class balancing to handle imbalanced datasets
567
+ - Early stopping to prevent overfitting
568
+ - Learning rate reduction on plateau
569
+ - Model checkpointing to save best weights
570
+ - Adaptive batch sizing based on training mode
571
+
572
+ Args:
573
+ X_train: Training features (embeddings or images)
574
+ y_train: Training labels (0 for benign, 1 for malignant)
575
+ X_val: Validation features
576
+ y_val: Validation labels
577
+ epochs (int): Maximum number of training epochs. Default: 50
578
+
579
+ Returns:
580
+ tf.keras.callbacks.History: Training history containing loss and accuracy curves
581
+
582
+ Note:
583
+ The method automatically handles class imbalance by computing balanced weights.
584
+ Training uses different batch sizes: 32 for fine-tuning, 64 for feature extraction.
585
+
586
+ Callbacks Used:
587
+ - EarlyStopping: Stops training if validation accuracy doesn't improve for 10 epochs
588
+ - ReduceLROnPlateau: Reduces learning rate by 50% if validation loss plateaus
589
+ - ModelCheckpoint: Saves the best model based on validation accuracy
590
+
591
+ Example:
592
+ >>> # Train the model
593
+ >>> history = classifier.train_model(X_train, y_train, X_val, y_val, epochs=30)
594
+ >>>
595
+ >>> # Access training metrics
596
+ >>> print(f"Final training accuracy: {history.history['accuracy'][-1]:.4f}")
597
+ >>> print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
598
+ """
599
+ # Compute class weights to handle imbalanced datasets
600
+ try:
601
+ classes = np.unique(y_train)
602
+ weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
603
+ class_weight = {int(c): float(w) for c, w in zip(classes, weights)}
604
+ print(f"Class weights computed: {class_weight}")
605
+ except Exception:
606
+ class_weight = None
607
+ print("Using uniform class weights")
608
+
609
+ # Define training callbacks for robust training
610
+ callbacks = [
611
+ tf.keras.callbacks.EarlyStopping(
612
+ monitor='val_accuracy',
613
+ patience=10,
614
+ restore_best_weights=True,
615
+ verbose=1
616
+ ),
617
+ tf.keras.callbacks.ReduceLROnPlateau(
618
+ monitor='val_loss',
619
+ factor=0.5,
620
+ patience=5,
621
+ min_lr=1e-7,
622
+ verbose=1
623
+ ),
624
+ tf.keras.callbacks.ModelCheckpoint(
625
+ 'best_model.keras',
626
+ monitor='val_accuracy',
627
+ save_best_only=True,
628
+ verbose=0
629
+ )
630
+ ]
631
+
632
+ print("Starting model training...")
633
+ print(f"Training samples: {len(X_train)}, Validation samples: {len(X_val)}")
634
+
635
+ # Adaptive batch sizing based on training mode
636
+ batch_size = 32 if self.fine_tune else 64
637
+ print(f"Using batch size: {batch_size}")
638
+
639
+ # Train the model
640
+ self.history = self.model.fit(
641
+ X_train, y_train,
642
+ validation_data=(X_val, y_val),
643
+ epochs=epochs,
644
+ batch_size=batch_size,
645
+ callbacks=callbacks,
646
+ verbose=1,
647
+ class_weight=class_weight
648
+ )
649
+
650
+ print("Training completed successfully!")
651
+ return self.history
652
+
653
+ def evaluate_model(self, X_test, y_test):
654
+ """
655
+ Comprehensive model evaluation with multiple performance metrics and visualisations.
656
+
657
+ This method provides a thorough evaluation of the trained model including:
658
+ - Accuracy, Precision, Recall, and F1-score calculations
659
+ - Detailed classification report with per-class metrics
660
+ - Confusion matrix visualisation and analysis
661
+ - Model predictions and probabilities for further analysis
662
+
663
+ Args:
664
+ X_test: Test features (embeddings or images)
665
+ y_test: True test labels (0 for benign, 1 for malignant)
666
+
667
+ Returns:
668
+ dict: Dictionary containing comprehensive evaluation results:
669
+ - 'accuracy': Overall accuracy score
670
+ - 'precision': Weighted average precision
671
+ - 'recall': Weighted average recall
672
+ - 'f1': Weighted average F1-score
673
+ - 'predictions': Predicted class labels
674
+ - 'probabilities': Prediction probabilities for each class
675
+ - 'confusion_matrix': 2x2 confusion matrix
676
+
677
+ Note:
678
+ The method generates and saves a confusion matrix plot as 'confusion_matrix.png'.
680
+ Because a non-interactive matplotlib backend is configured, the figure is written
+ to disk rather than displayed. The plot uses a blue color scheme for clarity.
680
+
681
+ Metrics Explanation:
682
+ - Accuracy: Overall correctness of predictions
683
+ - Precision: True positives / (True positives + False positives)
684
+ - Recall: True positives / (True positives + False negatives)
685
+ - F1-score: Harmonic mean of precision and recall
686
+
687
+ Example:
688
+ >>> # Evaluate the trained model
689
+ >>> results = classifier.evaluate_model(X_test, y_test)
690
+ >>>
691
+ >>> # Access specific metrics
692
+ >>> print(f"Test Accuracy: {results['accuracy']:.4f}")
693
+ >>> print(f"F1-Score: {results['f1']:.4f}")
694
+ >>>
695
+ >>> # Analyze predictions
696
+ >>> predictions = results['predictions']
697
+ >>> probabilities = results['probabilities']
698
+ """
699
+ print("Evaluating model performance...")
700
+
701
+ # Generate predictions and probabilities
702
+ y_pred_proba = self.model.predict(X_test)
703
+ y_pred = np.argmax(y_pred_proba, axis=1)
704
+
705
+ # Calculate comprehensive metrics
706
+ accuracy = accuracy_score(y_test, y_pred)
707
+ precision = precision_score(y_test, y_pred, average='weighted')
708
+ recall = recall_score(y_test, y_pred, average='weighted')
709
+ f1 = f1_score(y_test, y_pred, average='weighted')
710
+
711
+ # Display results
712
+ print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
713
+ print(f"Precision: {precision:.4f}")
714
+ print(f"Recall: {recall:.4f}")
715
+ print(f"F1-Score: {f1:.4f}")
716
+
717
+ # Detailed classification report
718
+ class_names = ['Benign', 'Malignant']
719
+ print("\nDetailed Classification Report:")
720
+ print(classification_report(y_test, y_pred, target_names=class_names))
721
+
722
+ # Generate and display confusion matrix
723
+ cm = confusion_matrix(y_test, y_pred)
724
+
725
+ # Create confusion matrix visualization
726
+ plt.figure(figsize=(8, 6))
727
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
728
+ xticklabels=class_names, yticklabels=class_names)
729
+ plt.title('Confusion Matrix - Breast Cancer Classification')
730
+ plt.xlabel('Predicted Label')
731
+ plt.ylabel('True Label')
732
+ plt.tight_layout()
733
+ plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
734
+ # Close the figure to free resources and avoid blocking
735
+ plt.close()
736
+
737
+ # Print confusion matrix in text format
738
+ print("\nConfusion Matrix:")
739
+ print(f" Predicted")
740
+ print(f" {class_names[0]:>8} {class_names[1]:>8}")
741
+ print(f"Actual {class_names[0]:>6} {cm[0,0]:>8} {cm[0,1]:>8}")
742
+ print(f" {class_names[1]:>6} {cm[1,0]:>8} {cm[1,1]:>8}")
743
+
744
+ return {
745
+ 'accuracy': accuracy,
746
+ 'precision': precision,
747
+ 'recall': recall,
748
+ 'f1': f1,
749
+ 'predictions': y_pred,
750
+ 'probabilities': y_pred_proba,
751
+ 'confusion_matrix': cm
752
+ }
753
+
754
+ def load_breakhis_data(data_dir="datasets/breakhis/histology_slides/breast", max_samples_per_class=2000, magnification="40X"):
755
+ """
756
+ Load and preprocess the BreakHis breast cancer histopathology dataset.
757
+
758
+ The BreakHis dataset contains microscopic images of breast tumor tissue
759
+ collected from clinical studies. Images are organized by:
760
+ - Tumor type (benign/malignant)
761
+ - Specific histological type (adenosis, fibroadenoma, etc.)
762
+ - Patient ID
763
+ - Magnification level (40X, 100X, 200X, 400X)
764
+
765
+ This function loads images from the specified magnification level and
766
+ preprocesses them for use with the Path Foundation model.
767
+
768
+ Args:
769
+ data_dir (str): Path to BreakHis dataset root directory. Default structure:
770
+ datasets/breakhis/histology_slides/breast/
771
+ max_samples_per_class (int): Maximum images to load per class (benign/malignant).
772
+ Helps manage memory usage for large datasets.
773
+ magnification (str): Magnification level to use. Options: "40X", "100X", "200X", "400X".
774
+ Higher magnifications provide more detail but larger file sizes.
775
+
776
+ Returns:
777
+ tuple: (images, labels) as numpy arrays
778
+ - images: Array of shape (num_images, 224, 224, 3) with normalized pixel values
779
+ - labels: Array of shape (num_images,) with 0 for benign, 1 for malignant
780
+
781
+ Dataset Structure:
782
+ The function expects the following directory structure:
783
+ data_dir/
784
+ ├── benign/SOB/
785
+ │ ├── adenosis/
786
+ │ ├── fibroadenoma/
787
+ │ ├── phyllodes_tumor/
788
+ │ └── tubular_adenoma/
789
+ └── malignant/SOB/
790
+ ├── ductal_carcinoma/
791
+ ├── lobular_carcinoma/
792
+ ├── mucinous_carcinoma/
793
+ └── papillary_carcinoma/
794
+
795
+ Note:
796
+ Images are automatically resized to 224x224 pixels and normalized to [0,1] range.
797
+ The function handles various image formats (PNG, JPG, JPEG, TIF, TIFF).
798
+
799
+ Example:
800
+ >>> # Load BreakHis dataset with 40X magnification
801
+ >>> images, labels = load_breakhis_data(
802
+ ... data_dir="datasets/breakhis/histology_slides/breast",
803
+ ... max_samples_per_class=1000,
804
+ ... magnification="40X"
805
+ ... )
806
+ >>> print(f"Loaded {len(images)} images")
807
+ >>> print(f"Benign: {np.sum(labels == 0)}, Malignant: {np.sum(labels == 1)}")
808
+ """
809
+ print(f"Loading BreakHis dataset (magnification: {magnification})...")
810
+
811
+ benign_dir = os.path.join(data_dir, "benign", "SOB")
812
+ malignant_dir = os.path.join(data_dir, "malignant", "SOB")
813
+
814
+ images = []
815
+ labels = []
816
+
817
+ def load_images_from_category(base_dir, label, max_count):
818
+ """
819
+ Helper function to load images from a specific category (benign/malignant).
820
+
821
+ Traverses the directory structure: base_dir/tumor_type/patient_id/magnification/images
822
+ and loads images with progress reporting.
823
+ """
824
+ if not os.path.exists(base_dir):
825
+ print(f"Warning: Directory {base_dir} not found")
826
+ return 0
827
+
828
+ count = 0
829
+
830
+ # Traverse: base_dir/tumor_type/patient_id/magnification/images
831
+ for tumor_type in os.listdir(base_dir):
832
+ tumor_dir = os.path.join(base_dir, tumor_type)
833
+ if not os.path.isdir(tumor_dir):
834
+ continue
835
+
836
+ for patient_id in os.listdir(tumor_dir):
837
+ patient_dir = os.path.join(tumor_dir, patient_id)
838
+ if not os.path.isdir(patient_dir):
839
+ continue
840
+
841
+ mag_dir = os.path.join(patient_dir, magnification)
842
+ if not os.path.exists(mag_dir):
843
+ continue
844
+
845
+ for filename in os.listdir(mag_dir):
846
+ if count >= max_count:
847
+ return count
848
+
849
+ if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')):
850
+ try:
851
+ img_path = os.path.join(mag_dir, filename)
852
+ img = Image.open(img_path).convert('RGB')
853
+ img = img.resize((224, 224))
854
+ img_array = np.array(img).astype('float32') / 255.0
855
+ images.append(img_array)
856
+ labels.append(label)
857
+ count += 1
858
+
859
+ if count % 100 == 0:
860
+ category = 'benign' if label == 0 else 'malignant'
861
+ print(f"Loaded {count} {category} images")
862
+
863
+ except Exception as e:
864
+ print(f"Error loading {filename}: {e}")
865
+ continue
866
+ return count
867
+
868
+ # Load both categories
869
+ benign_count = load_images_from_category(benign_dir, 0, max_samples_per_class)
870
+ malignant_count = load_images_from_category(malignant_dir, 1, max_samples_per_class)
871
+
872
+ print(f"BreakHis dataset loaded: {benign_count} benign, {malignant_count} malignant images")
873
+
874
+ return np.array(images), np.array(labels)
875
+
876
+ def load_pcam_data(data_dir="datasets/pcam", label_dir="datasets/Labels", max_samples=3000, augment=True):
877
+ """
878
+ Load and preprocess the PatchCamelyon (PCam) dataset.
879
+
880
+ PCam contains 96x96 pixel patches extracted from histopathologic scans
881
+ of lymph node sections. Each patch is labeled with the presence of
882
+ metastatic tissue. This function includes data augmentation capabilities
883
+ to improve model generalization.
884
+
885
+ The dataset is stored in HDF5 format with separate files for images and labels,
886
+ and comes pre-split into training, validation, and test sets.
887
+
888
+ Args:
889
+ data_dir (str): Path to PCam image data directory containing:
890
+ - training_split.h5
891
+ - validation_split.h5
892
+ - test_split.h5
893
+ label_dir (str): Path to PCam label files directory containing:
894
+ - camelyonpatch_level_2_split_train_y.h5
895
+ - camelyonpatch_level_2_split_valid_y.h5
896
+ - camelyonpatch_level_2_split_test_y.h5
897
+ max_samples (int): Maximum total samples to load across all splits.
898
+ Distributed as: train=50%, val=25%, test=25%
899
+ augment (bool): Whether to apply data augmentation to training set.
900
+ Augmentation includes: horizontal flip, rotation, brightness adjustment
901
+
902
+ Returns:
903
+ dict: Dictionary with 'train', 'valid', 'test' keys containing (images, labels) tuples
904
+ - 'train': (train_images, train_labels) - Training data with optional augmentation
905
+ - 'valid': (val_images, val_labels) - Validation data
906
+ - 'test': (test_images, test_labels) - Test data
907
+
908
+ Dataset Details:
909
+ - Original patch size: 96x96 pixels
910
+ - Resized to: 224x224 pixels for Path Foundation compatibility
911
+ - Labels: 0 (normal tissue), 1 (metastatic tissue)
912
+ - Format: HDF5 files with 'x' key for images, 'y' key for labels
913
+
914
+ Data Augmentation (if enabled):
915
+ - Horizontal flip: 50% probability
916
+ - Rotation: Random 0°, 90°, 180°, or 270° rotation
917
+ - Brightness adjustment: 30% probability, factor between 0.9-1.1
918
+
919
+ Note:
920
+ The function automatically handles HDF5 file loading and memory management.
921
+ Images are resized from 96x96 to 224x224 pixels and normalized to [0,1] range.
922
+
923
+ Example:
924
+ >>> # Load PCam dataset with augmentation
925
+ >>> pcam_data = load_pcam_data(
926
+ ... data_dir="datasets/pcam",
927
+ ... label_dir="datasets/Labels",
928
+ ... max_samples=2000,
929
+ ... augment=True
930
+ ... )
931
+ >>>
932
+ >>> # Access training data
933
+ >>> train_images, train_labels = pcam_data['train']
934
+ >>> print(f"Training samples: {len(train_images)}")
935
+ >>> print(f"Image shape: {train_images[0].shape}")
936
+ """
937
+ print("Loading PatchCamelyon (PCam) dataset...")
938
+
939
+ # Define file paths
940
+ train_file = os.path.join(data_dir, "training_split.h5")
941
+ val_file = os.path.join(data_dir, "validation_split.h5")
942
+ test_file = os.path.join(data_dir, "test_split.h5")
943
+ train_label_file = os.path.join(label_dir, "camelyonpatch_level_2_split_train_y.h5")
944
+ val_label_file = os.path.join(label_dir, "camelyonpatch_level_2_split_valid_y.h5")
945
+ test_label_file = os.path.join(label_dir, "camelyonpatch_level_2_split_test_y.h5")
946
+
947
+ def preprocess(images):
948
+ """Resize and normalize images from 96x96 to 224x224 pixels."""
949
+ processed = []
950
+ for img in images:
951
+ im = Image.fromarray(img)
952
+ im = im.resize((224, 224)) # Resize to match Path Foundation input
953
+ arr = np.array(im).astype('float32') / 255.0
954
+ processed.append(arr)
955
+ return np.array(processed)
956
+
957
+ def safe_load(img_file, label_file, limit):
958
+ """Safely load data from HDF5 files with memory management."""
959
+ with h5py.File(img_file, 'r') as f_img, h5py.File(label_file, 'r') as f_lbl:
960
+ x = f_img['x'][:limit]
961
+ y = f_lbl['y'][:limit]
962
+ y = y.reshape(-1) # Ensure 1D label array
963
+ return x, y
964
+
965
+ # Load data splits with sample limits
966
+ train_images, train_labels = safe_load(train_file, train_label_file, max_samples//2)
967
+ val_images, val_labels = safe_load(val_file, val_label_file, max_samples//4)
968
+ test_images, test_labels = safe_load(test_file, test_label_file, max_samples//4)
969
+
970
+ # Preprocess all splits
971
+ train_images = preprocess(train_images)
972
+ val_images = preprocess(val_images)
973
+ test_images = preprocess(test_images)
974
+
975
+ # Apply data augmentation to training set
976
+ if augment:
977
+ print("Applying data augmentation to training set...")
978
+ for i in range(len(train_images)):
979
+ # Random horizontal flip
980
+ if np.random.rand() > 0.5:
981
+ train_images[i] = np.fliplr(train_images[i])
982
+
983
+ # Random rotation (0, 90, 180, 270 degrees)
984
+ k = np.random.randint(0, 4)
985
+ if k:
986
+ train_images[i] = np.rot90(train_images[i], k)
987
+
988
+ # Random brightness adjustment
989
+ if np.random.rand() > 0.7:
990
+ im = Image.fromarray((train_images[i] * 255).astype('uint8'))
991
+ brightness_factor = 0.9 + 0.2 * np.random.rand()
992
+ im = Image.fromarray(
993
+ np.clip(np.array(im, dtype=np.float32) * brightness_factor, 0, 255).astype('uint8')
994
+ )
995
+ train_images[i] = np.array(im).astype('float32') / 255.0
996
+
997
+ print(f"PCam dataset loaded - Train: {len(train_images)}, Val: {len(val_images)}, Test: {len(test_images)}")
998
+
999
+ return {
1000
+ 'train': (train_images, train_labels),
1001
+ 'valid': (val_images, val_labels),
1002
+ 'test': (test_images, test_labels)
1003
+ }
1004
+
1005
+ def load_bach_data(data_dir="datasets/BACH/ICIAR2018_BACH_Challenge/Photos", max_samples=400, augment=True):
1006
+ """
1007
+ Load and preprocess the BACH (ICIAR 2018) breast cancer histology dataset.
1008
+
1009
+ BACH contains microscopy images classified into four categories:
1010
+ - Normal tissue
1011
+ - Benign lesions
1012
+ - In situ carcinoma
1013
+ - Invasive carcinoma
1014
+
1015
+ For binary classification, this function maps:
1016
+ - Normal + Benign → Benign (label 0)
1017
+ - In situ + Invasive → Malignant (label 1)
1018
+
1019
+ Args:
1020
+ data_dir (str): Path to BACH dataset directory containing class subdirectories:
1021
+ - Normal/
1022
+ - Benign/
1023
+ - InSitu/
1024
+ - Invasive/
1025
+ max_samples (int): Maximum total samples to load across all classes.
1026
+ Distributed evenly across the 4 classes.
1027
+ augment (bool): Whether to apply data augmentation (currently not implemented
1028
+ for BACH dataset but parameter kept for consistency)
1029
+
1030
+ Returns:
1031
+ dict: Dictionary with 'train', 'valid', 'test' keys containing (images, labels) tuples
1032
+ - 'train': (train_images, train_labels) - Training data
1033
+ - 'valid': (val_images, val_labels) - Validation data
1034
+ - 'test': (test_images, test_labels) - Test data
1035
+
1036
+ Dataset Details:
1037
+ - Original categories: 4 classes (Normal, Benign, InSitu, Invasive)
1038
+ - Binary mapping: Normal(0), Benign(1) → Benign(0); InSitu(2), Invasive(3) → Malignant(1)
1039
+ - Image format: TIF, TIFF, PNG, JPG, JPEG
1040
+ - Resized to: 224x224 pixels for Path Foundation compatibility
1041
+ - Normalized to: [0, 1] range
1042
+
1043
+ Data Splitting:
1044
+ - Test set: 20% of total data
1045
+ - Training set: 60% of total data (75% of remaining after test split)
1046
+ - Validation set: 20% of total data (25% of remaining after test split)
1047
+ - Stratified splitting to maintain class distribution
1048
+
1049
+ Note:
1050
+ The function automatically handles the 4-class to binary classification mapping.
1051
+ Images are resized to 224x224 pixels and normalized to [0,1] range.
1052
+ The augment parameter is kept for API consistency but augmentation is not
1053
+ currently implemented for the BACH dataset.
1054
+
1055
+ Example:
1056
+ >>> # Load BACH dataset
1057
+ >>> bach_data = load_bach_data(
1058
+ ... data_dir="datasets/BACH/ICIAR2018_BACH_Challenge/Photos",
1059
+ ... max_samples=400,
1060
+ ... augment=True
1061
+ ... )
1062
+ >>>
1063
+ >>> # Access training data
1064
+ >>> train_images, train_labels = bach_data['train']
1065
+ >>> print(f"Training samples: {len(train_images)}")
1066
+ >>> print(f"Class distribution: Benign={np.sum(train_labels==0)}, Malignant={np.sum(train_labels==1)}")
1067
+ """
1068
+ print("Loading BACH (ICIAR 2018) dataset...")
1069
+
1070
+ # Original BACH categories mapped to binary classification
1071
+ class_dirs = {
1072
+ 'Normal': 0, # Normal tissue → Benign
1073
+ 'Benign': 1, # Benign lesions → Benign
1074
+ 'InSitu': 2, # In situ carcinoma → Malignant
1075
+ 'Invasive': 3, # Invasive carcinoma → Malignant
1076
+ }
1077
+
1078
+ images = []
1079
+ labels = []
1080
+ per_class_limit = None if not max_samples else max_samples // 4
1081
+ counters = {0: 0, 1: 0, 2: 0, 3: 0}
1082
+
1083
+ # Load images from each category
1084
+ for cls_name, cls_label in class_dirs.items():
1085
+ cls_path = os.path.join(data_dir, cls_name)
1086
+ if not os.path.isdir(cls_path):
1087
+ print(f"Warning: Directory {cls_path} not found")
1088
+ continue
1089
+
1090
+ for fname in os.listdir(cls_path):
1091
+ if per_class_limit and counters[cls_label] >= per_class_limit:
1092
+ break
1093
+ if not fname.lower().endswith((".tif", ".tiff", ".png", ".jpg", ".jpeg")):
1094
+ continue
1095
+
1096
+ fpath = os.path.join(cls_path, fname)
1097
+ try:
1098
+ im = Image.open(fpath).convert('RGB')
1099
+ im = im.resize((224, 224))
1100
+ arr = np.array(im).astype('float32') / 255.0
1101
+ images.append(arr)
1102
+ labels.append(cls_label)
1103
+ counters[cls_label] += 1
1104
+ except Exception as e:
1105
+ print(f"Error loading {fname}: {e}")
1106
+ continue
1107
+
1108
+ images = np.array(images)
1109
+ labels = np.array(labels)
1110
+
1111
+ # Convert 4-class to binary classification
1112
+ if labels.size > 0:
1113
+ # Map: Normal(0), Benign(1) → Benign(0); InSitu(2), Invasive(3) → Malignant(1)
1114
+ labels = np.where(np.isin(labels, [0, 1]), 0, 1)
1115
+
1116
+ print(f"BACH dataset loaded: {len(images)} images")
1117
+ print(f"Class distribution - Benign: {np.sum(labels == 0)}, Malignant: {np.sum(labels == 1)}")
1118
+
1119
+ # Split into train/validation/test sets
1120
+ X_temp, X_test, y_temp, y_test = train_test_split(
1121
+ images, labels, test_size=0.2,
1122
+ stratify=labels if len(set(labels)) > 1 else None,
1123
+ random_state=42
1124
+ )
1125
+ X_train, X_val, y_train, y_val = train_test_split(
1126
+ X_temp, y_temp, test_size=0.25,
1127
+ stratify=y_temp if len(set(y_temp)) > 1 else None,
1128
+ random_state=42
1129
+ )
1130
+
1131
+ return {
1132
+ 'train': (X_train, y_train),
1133
+ 'valid': (X_val, y_val),
1134
+ 'test': (X_test, y_test)
1135
+ }
1136
+
1137
+ def load_combined_data(dataset_choice="breakhis", max_samples=5000):
1138
+ """
1139
+ Unified data loading function supporting multiple datasets and combinations.
1140
+
1141
+ This function serves as the main entry point for data loading, supporting:
1142
+ - Individual datasets (BreakHis, PCam, BACH)
1143
+ - Combined dataset training for improved generalization
1144
+ - Consistent data splitting and preprocessing across all datasets
1145
+
1146
+ The combined dataset approach leverages multiple histopathology datasets to
1147
+ create a more robust and generalizable model by training on diverse data sources.
1148
+
1149
+ Args:
1150
+ dataset_choice (str): Dataset to load. Options:
1151
+ - "breakhis": BreakHis breast cancer histopathology dataset
1152
+ - "pcam": PatchCamelyon lymph node metastasis dataset
1153
+ - "bach": BACH ICIAR 2018 breast cancer histology dataset
1154
+ - "combined": Ensemble of all three datasets for robust training
1155
+ max_samples (int): Maximum total samples to load. For individual datasets,
1156
+ this is the total limit. For combined datasets, this is
1157
+ distributed across the constituent datasets.
1158
+
1159
+ Returns:
1160
+ dict: Dictionary with 'train', 'valid', 'test' keys containing (images, labels) tuples
1161
+ - 'train': (train_images, train_labels) - Training data
1162
+ - 'valid': (val_images, val_labels) - Validation data
1163
+ - 'test': (test_images, test_labels) - Test data
1164
+
1165
+ Dataset Combinations:
1166
+ When dataset_choice="combined", the function:
1167
+ 1. Loads BreakHis, PCam, and BACH datasets
1168
+ 2. Combines their training data
1169
+ 3. Shuffles the combined dataset
1170
+ 4. Splits into train/validation/test sets
1171
+ 5. Maintains class balance through stratified splitting
1172
+
1173
+ Sample Distribution (for combined datasets):
1174
+ - BreakHis: max_samples // 6 (per-class limit)
1175
+ - PCam: max_samples // 3 (total limit)
1176
+ - BACH: max_samples // 3 (total limit)
1177
+
1178
+ Data Splitting:
1179
+ - Test set: 20% of total data
1180
+ - Training set: 60% of total data (75% of remaining after test split)
1181
+ - Validation set: 20% of total data (25% of remaining after test split)
1182
+ - Stratified splitting to maintain class distribution
1183
+
1184
+ Note:
1185
+ All datasets are automatically preprocessed to 224x224 pixels and normalized
1186
+ to [0,1] range for compatibility with the Path Foundation model.
1187
+
1188
+ Example:
1189
+ >>> # Load individual dataset
1190
+ >>> data = load_combined_data("breakhis", max_samples=2000)
1191
+ >>>
1192
+ >>> # Load combined dataset for robust training
1193
+ >>> combined_data = load_combined_data("combined", max_samples=6000)
1194
+ >>>
1195
+ >>> # Access training data
1196
+ >>> train_images, train_labels = combined_data['train']
1197
+ >>> print(f"Combined training samples: {len(train_images)}")
1198
+ """
1199
+
1200
+ if dataset_choice.lower() == "breakhis":
1201
+ print("Loading BreakHis dataset only...")
1202
+ images, labels = load_breakhis_data(max_samples_per_class=max_samples//2)
1203
+
1204
+ # Split into train/validation/test
1205
+ X_temp, X_test, y_temp, y_test = train_test_split(
1206
+ images, labels, test_size=0.2, stratify=labels, random_state=42
1207
+ )
1208
+
1209
+ X_train, X_val, y_train, y_val = train_test_split(
1210
+ X_temp, y_temp, test_size=0.25, stratify=y_temp, random_state=42
1211
+ )
1212
+
1213
+ return {
1214
+ 'train': (X_train, y_train),
1215
+ 'valid': (X_val, y_val),
1216
+ 'test': (X_test, y_test)
1217
+ }
1218
+
1219
+ elif dataset_choice.lower() == "pcam":
1220
+ return load_pcam_data(max_samples=max_samples)
1221
+
1222
+ elif dataset_choice.lower() == "bach":
1223
+ return load_bach_data(max_samples=max_samples)
1224
+
1225
+ elif dataset_choice.lower() == "combined":
1226
+ print("Loading combined datasets for enhanced generalization...")
1227
+
1228
+ # Distribute samples across datasets
1229
+ if max_samples is None:
1230
+ per_bh = None
1231
+ per_pc = None
1232
+ per_ba = None
1233
+ else:
1234
+ per_dataset = max(1, max_samples // 3)
1235
+ per_bh = per_dataset // 2 # BreakHis uses per-class limit
1236
+ per_pc = per_dataset
1237
+ per_ba = per_dataset
1238
+
1239
+ # Load individual datasets
1240
+ print("Loading BreakHis component...")
1241
+ bh_images, bh_labels = load_breakhis_data(
1242
+ max_samples_per_class=per_bh if per_bh else 10**9
1243
+ )
1244
+
1245
+ print("Loading PCam component...")
1246
+ pcam = load_pcam_data(max_samples=per_pc, augment=True)
1247
+ pc_train_images, pc_train_labels = pcam["train"]
1248
+
1249
+ print("Loading BACH component...")
1250
+ bach = load_bach_data(max_samples=per_ba, augment=True)
1251
+ b_train_images, b_train_labels = bach["train"]
1252
+
1253
+ # Combine all datasets
1254
+ images = np.concatenate([bh_images, pc_train_images, b_train_images], axis=0)
1255
+ labels = np.concatenate([bh_labels, pc_train_labels, b_train_labels], axis=0)
1256
+
1257
+ print(f"Combined dataset: {len(images)} total images")
1258
+ print(f"Final distribution - Benign: {np.sum(labels == 0)}, Malignant: {np.sum(labels == 1)}")
1259
+
1260
+ # Shuffle combined data
1261
+ idx = np.arange(len(images))
1262
+ np.random.shuffle(idx)
1263
+ images, labels = images[idx], labels[idx]
1264
+
1265
+ # Split combined data
1266
+ X_temp, X_test, y_temp, y_test = train_test_split(
1267
+ images, labels, test_size=0.2,
1268
+ stratify=labels if len(set(labels)) > 1 else None,
1269
+ random_state=42
1270
+ )
1271
+ X_train, X_val, y_train, y_val = train_test_split(
1272
+ X_temp, y_temp, test_size=0.25,
1273
+ stratify=y_temp if len(set(y_temp)) > 1 else None,
1274
+ random_state=42
1275
+ )
1276
+
1277
+ return {
1278
+ 'train': (X_train, y_train),
1279
+ 'valid': (X_val, y_val),
1280
+ 'test': (X_test, y_test)
1281
+ }
1282
+
1283
+ else:
1284
+ raise ValueError(f"Unknown dataset choice: {dataset_choice}. "
1285
+ f"Choose from: 'breakhis', 'pcam', 'bach', 'combined'")
1286
+
1287
+ def main():
1288
+ """
1289
+ Execute the complete breast cancer classification pipeline.
1290
+
1291
+ This function coordinates all components of the machine learning workflow:
1292
+ 1. Environment validation and setup
1293
+ 2. Model authentication and loading
1294
+ 3. Dataset loading and preprocessing
1295
+ 4. Feature extraction using Path Foundation
1296
+ 5. Classifier training with advanced techniques
1297
+ 6. Comprehensive model evaluation
1298
+ 7. Model persistence for future use
1299
+
1300
+ The pipeline implements a robust transfer learning approach using Google's
1301
+ Path Foundation model as a feature extractor, followed by a trainable
1302
+ classification head for binary breast cancer classification.
1303
+
1304
+ Returns:
1305
+ tuple: (classifier_instance, evaluation_results) or (None, None) if failed
1306
+ - classifier_instance: Trained BreastCancerClassifier object
1307
+ - evaluation_results: Dictionary containing performance metrics and predictions
1308
+
1309
+ Configuration:
1310
+ The function uses global variables for configuration (can be modified):
1311
+ - DATASET_CHOICE: Dataset to use ("breakhis", "pcam", "bach", "combined")
1312
+ - MAX_SAMPLES: Maximum samples to load (adjust based on available memory)
1313
+ - EPOCHS: Number of training epochs (default: 50)
1314
+ - HF_TOKEN: Hugging Face authentication token (optional)
1315
+
1316
+ Pipeline Steps:
1317
+ 1. Prerequisites Check: Validates required packages and dependencies
1318
+ 2. Authentication: Authenticates with Hugging Face Hub
1319
+ 3. Model Loading: Downloads and loads Path Foundation model
1320
+ 4. Data Loading: Loads and preprocesses histopathology dataset
1321
+ 5. Feature Extraction: Extracts embeddings using frozen foundation model
1322
+ 6. Classifier Building: Constructs trainable classification head
1323
+ 7. Training: Trains classifier with callbacks and monitoring
1324
+ 8. Evaluation: Comprehensive performance assessment
1325
+ 9. Model Saving: Persists trained model for future use
1326
+
1327
+ Error Handling:
1328
+ The function includes comprehensive error handling with detailed error messages
1329
+ and stack traces to aid in debugging and troubleshooting.
1330
+
1331
+ Example:
1332
+ >>> # Run the complete pipeline
1333
+ >>> classifier, results = main()
1334
+ >>>
1335
+ >>> if results:
1336
+ ... print(f"Pipeline successful! Accuracy: {results['accuracy']:.4f}")
1337
+ ... # Use the trained classifier for inference
1338
+ ... else:
1339
+ ... print("Pipeline failed - check error messages")
1340
+
1341
+ Note:
1342
+ This function is designed to be run as a standalone script or imported
1343
+ and called from other modules. It provides a complete end-to-end
1344
+ machine learning pipeline for breast cancer classification.
1345
+ """
1346
+ print("="*60)
1347
+ print("BREAST CANCER CLASSIFICATION WITH PATH FOUNDATION")
1348
+ print("="*60)
1349
+
1350
+ # Validate prerequisites
1351
+ if not HF_AVAILABLE:
1352
+ print("ERROR: Prerequisites not met")
1353
+ print("Required installations: pip install tensorflow huggingface_hub transformers")
1354
+ return None, None
1355
+
1356
+ # Configuration parameters
1357
+ EPOCHS = 50
1358
+ HF_TOKEN = None # Set your Hugging Face token here if needed
1359
+
1360
+ # Global configuration (can be overridden by defining these names in a notebook).
1361
+ # Read any overrides via globals(): assigning to these names below makes them
1362
+ # local to main(), so a plain "if ... not in globals()" guard would leave them
1363
+ # unbound whenever a global override exists.
+ DATASET_CHOICE = globals().get('DATASET_CHOICE', 'combined')  # Options: 'breakhis', 'pcam', 'bach', 'combined'
1364
+ MAX_SAMPLES = globals().get('MAX_SAMPLES', 4000)
1365
+
1366
+ print(f"Configuration:")
1367
+ print(f" - Epochs: {EPOCHS}")
1368
+ print(f" - Dataset: {DATASET_CHOICE}")
1369
+ print(f" - Max samples: {MAX_SAMPLES}")
1370
+ print(f" - Method: Feature extraction (frozen foundation model)")
1371
+
1372
+ try:
1373
+ # Initialize classifier in feature extraction mode
1374
+ classifier = BreastCancerClassifier(fine_tune=False)
1375
+
1376
+ print("\n" + "="*40)
1377
+ print("STEP 1: HUGGING FACE AUTHENTICATION")
1378
+ print("="*40)
1379
+ if not classifier.authenticate_huggingface(HF_TOKEN):
1380
+ raise Exception("Authentication failed - check your HF token")
1381
+
1382
+ print("\n" + "="*40)
1383
+ print("STEP 2: LOADING PATH FOUNDATION MODEL")
1384
+ print("="*40)
1385
+ if not classifier.load_path_foundation():
1386
+ raise Exception("Model loading failed - check network connection")
1387
+
1388
+ print("\n" + "="*40)
1389
+ print(f"STEP 3: LOADING {DATASET_CHOICE.upper()} DATASET")
1390
+ print("="*40)
1391
+ data = load_combined_data(DATASET_CHOICE, MAX_SAMPLES)
1392
+
1393
+ X_train, y_train = data['train']
1394
+ X_val, y_val = data['valid']
1395
+ X_test, y_test = data['test']
1396
+
1397
+ print(f"Dataset splits:")
1398
+ print(f" - Training: {len(X_train)} samples")
1399
+ print(f" - Validation: {len(X_val)} samples")
1400
+ print(f" - Test: {len(X_test)} samples")
1401
+
1402
+ print("\n" + "="*40)
1403
+ print("STEP 4: EXTRACTING FEATURE EMBEDDINGS")
1404
+ print("="*40)
1405
+ print("Extracting training embeddings...")
1406
+ X_train = classifier.extract_embeddings(X_train)
1407
+ print("Extracting validation embeddings...")
1408
+ X_val = classifier.extract_embeddings(X_val)
1409
+ print("Extracting test embeddings...")
1410
+ X_test = classifier.extract_embeddings(X_test)
1411
+
1412
+ print("\n" + "="*40)
1413
+ print("STEP 5: BUILDING CLASSIFICATION HEAD")
1414
+ print("="*40)
1415
+ classifier.num_classes = 2
1416
+ classifier.build_classifier()
1417
+
1418
+ print("\n" + "="*40)
1419
+ print("STEP 6: TRAINING CLASSIFIER")
1420
+ print("="*40)
1421
+ classifier.train_model(X_train, y_train, X_val, y_val, EPOCHS)
1422
+
1423
+ print("\n" + "="*40)
1424
+ print("STEP 7: MODEL EVALUATION")
1425
+ print("="*40)
1426
+ results = classifier.evaluate_model(X_test, y_test)
1427
+
1428
+ # Save trained model
1429
+ model_name = f"{DATASET_CHOICE}_breast_cancer_classifier.keras"
1430
+ classifier.model.save(model_name)
1431
+ print(f"\nModel saved as: {model_name}")
1432
+
1433
+ print("\n" + "="*60)
1434
+ print("PIPELINE COMPLETED SUCCESSFULLY")
1435
+ print("="*60)
1436
+ print(f"Final Performance Metrics:")
1437
+ print(f" - Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
1438
+ print(f" - F1-Score: {results['f1']:.4f}")
1439
+ print(f" - Precision: {results['precision']:.4f}")
1440
+ print(f" - Recall: {results['recall']:.4f}")
1441
+
1442
+ return classifier, results
1443
+
1444
+ except Exception as e:
1445
+ print(f"\nERROR: Pipeline failed - {e}")
1446
+ import traceback
1447
+ traceback.print_exc()
1448
+ return None, None
1449
+
1450
+ # Script execution section
1451
+ if __name__ == "__main__":
1452
+ """
1453
+ Main execution block for running the breast cancer classification pipeline.
1454
+
1455
+ This section is executed when the script is run directly (not imported).
1456
+ It provides a simple interface to run the complete machine learning pipeline
1457
+ and displays the final results.
1458
+
1459
+ Usage:
1460
+ python model_histo.py
1461
+
1462
+ The script will:
1463
+ 1. Initialize and run the complete pipeline
1464
+ 2. Display progress and intermediate results
1465
+ 3. Show final performance metrics
1466
+ 4. Save the trained model for future use
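+ 
+ Inference sketch (illustrative — the filename assumes the pipeline was run with
+ DATASET_CHOICE='combined'; 'patch.png' is a placeholder image path):
+ 
+ >>> head = tf.keras.models.load_model('combined_breast_cancer_classifier.keras')
+ >>> clf = BreastCancerClassifier(fine_tune=False)
+ >>> clf.authenticate_huggingface()
+ >>> clf.load_path_foundation()
+ >>> emb = clf.extract_embeddings(['patch.png'])
+ >>> probs = head.predict(emb)   # probs[:, 1] is the predicted malignant probability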
1467
+ """
1468
+ print("Starting Breast Cancer Classification Pipeline...")
1469
+ print("This may take several minutes depending on your hardware and dataset size.")
1470
+ print("="*60)
1471
+
1472
+ # Execute the complete pipeline
1473
+ classifier, results = main()
1474
+
1475
+ # Display final results
1476
+ if results:
1477
+ print("\n" + "="*60)
1478
+ print("🎉 PIPELINE EXECUTION SUCCESSFUL! 🎉")
1479
+ print("="*60)
1480
+ print(f"Final Accuracy: {results['accuracy']:.4f} ({results['accuracy']*100:.2f}%)")
1481
+ print(f"F1-Score: {results['f1']:.4f}")
1482
+ print(f"Precision: {results['precision']:.4f}")
1483
+ print(f"Recall: {results['recall']:.4f}")
1484
+ print("\nThe trained model has been saved and is ready for inference!")
1485
+ print("You can now use the classifier for breast cancer classification tasks.")
1486
+ else:
1487
+ print("\n" + "="*60)
1488
+ print("❌ PIPELINE EXECUTION FAILED ❌")
1489
+ print("="*60)
1490
+ print("Please check the error messages above for troubleshooting.")
1491
+ print("Common issues:")
1492
+ print("- Missing dependencies (install with: pip install tensorflow huggingface_hub transformers)")
1493
+ print("- Network connectivity issues (for downloading Path Foundation model)")
1494
+ print("- Insufficient memory (reduce MAX_SAMPLES parameter)")
1495
+ print("- Invalid dataset paths (check dataset directory structure)")
backend/requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pillow
4
+ ultralytics
5
+ tensorflow
6
+ numpy
7
+ huggingface_hub
8
+ joblib
9
+ scikit-learn
10
+ scikit-image
11
+ loguru
12
+ thop
13
+ seaborn
14
+ python-multipart
backend/yolo_colposcopy.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f31fd29763c04b071b449db7bfa09743527f30a62913f25d5a6097366c4bf3b4
3
+ size 6235498
frontend/.eslintrc.cjs ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module.exports = {
2
+ root: true,
3
+ env: { browser: true, es2020: true },
4
+ extends: [
5
+ 'eslint:recommended',
6
+ 'plugin:@typescript-eslint/recommended',
7
+ 'plugin:react-hooks/recommended',
8
+ ],
9
+ ignorePatterns: ['dist', '.eslintrc.cjs'],
10
+ parser: '@typescript-eslint/parser',
11
+ plugins: ['react-refresh'],
12
+ rules: {
13
+ 'react-refresh/only-export-components': [
14
+ 'warn',
15
+ { allowConstantExport: true },
16
+ ],
17
+ },
18
+ }
frontend/.gitignore ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules
11
+ dist
12
+ dist-ssr
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/*
17
+ !.vscode/extensions.json
18
+ .idea
19
+ .DS_Store
20
+ *.suo
21
+ *.ntvs*
22
+ *.njsproj
23
+ *.sln
24
+ *.sw?
frontend/README.md ADDED
@@ -0,0 +1,8 @@
1
+ # Magic Patterns - Vite Template
2
+
3
+ This code was generated by [Magic Patterns](https://magicpatterns.com) for this design: [Source Design](https://www.magicpatterns.com/c/jitk86q9nv1at6tcwj7cxr)
4
+
5
+ ## Getting Started
6
+
7
+ 1. Run `npm install`
8
+ 2. Run `npm run dev`
frontend/index.html ADDED
@@ -0,0 +1,13 @@
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Manalife AI Pathology Assistant</title>
8
+ </head>
9
+ <body>
10
+ <div id="root"></div>
11
+ <script type="module" src="/src/index.tsx"></script>
12
+ </body>
13
+ </html>
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "name": "magic-patterns-vite-template",
3
+ "version": "0.0.1",
4
+ "private": true,
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "npx vite",
8
+ "build": "npx vite build",
9
+ "lint": "eslint . --ext .js,.jsx,.ts,.tsx",
10
+ "preview": "npx vite preview"
11
+ },
12
+ "dependencies": {
13
+ "axios": "^1.12.2",
14
+ "lucide-react": "0.522.0",
15
+ "react": "^18.3.1",
16
+ "react-dom": "^18.3.1",
17
+ "react-router-dom": "^6.26.2"
18
+ },
19
+ "devDependencies": {
20
+ "@types/node": "^20.11.18",
21
+ "@types/react": "^18.3.1",
22
+ "@types/react-dom": "^18.3.1",
23
+ "@typescript-eslint/eslint-plugin": "^5.54.0",
24
+ "@typescript-eslint/parser": "^5.54.0",
25
+ "@vitejs/plugin-react": "^4.2.1",
26
+ "autoprefixer": "latest",
27
+ "eslint": "^8.50.0",
28
+ "eslint-plugin-react-hooks": "^4.6.0",
29
+ "eslint-plugin-react-refresh": "^0.4.1",
30
+ "postcss": "latest",
31
+ "tailwindcss": "3.4.17",
32
+ "typescript": "^5.5.4",
33
+ "vite": "^5.2.0"
34
+ }
35
+ }
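The dev/build scripts above run Vite, while App.tsx (further down in this diff) calls a FastAPI backend on http://127.0.0.1:8000 during development. The repository's vite.config.ts is not part of this diff, so the following is only a sketch of how a dev proxy could be wired up; the plugin usage and the /predict path are assumptions based on the files shown here, not the project's actual configuration.

    // vite.config.ts: illustrative sketch only; the committed config is not shown in this diff.
    import { defineConfig } from "vite";
    import react from "@vitejs/plugin-react";

    export default defineConfig({
      plugins: [react()],
      server: {
        proxy: {
          // Forward API calls to the FastAPI backend during development (assumed to listen on port 8000).
          "/predict": {
            target: "http://127.0.0.1:8000",
            changeOrigin: true,
          },
        },
      },
    });

With a proxy like this, the frontend could post to a relative "/predict/" URL in both development and production instead of switching on import.meta.env.MODE.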
frontend/postcss.config.js ADDED
@@ -0,0 +1,6 @@
1
+ export default {
2
+ plugins: {
3
+ tailwindcss: {},
4
+ autoprefixer: {},
5
+ },
6
+ }
frontend/public/banner.jpeg ADDED
frontend/public/black_logo.png ADDED
frontend/public/colpo/colp1.jpg ADDED
frontend/public/colpo/colp2.jpg ADDED
frontend/public/colpo/colp3.jpg ADDED
frontend/public/cyto/cyt1.jpg ADDED

Git LFS Details

  • SHA256: 5272c04ec47e122e332f3197d8248298eb40f496b51cfe4c37f3f43ec5a9ea2c
  • Pointer size: 131 Bytes
  • Size of remote file: 716 kB
frontend/public/cyto/cyt2.png ADDED

Git LFS Details

  • SHA256: 339cb5b78762ac985f76a36be2165f4e2c8c473d1028957cb81d7f3a2050276d
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
frontend/public/cyto/cyt3.png ADDED

Git LFS Details

  • SHA256: 7533ab6eea48e4a4671c14c87334f8c387e39bf09e7995bbc960db6b04c2cba7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.86 MB
frontend/public/histo/hist1.png ADDED
frontend/public/histo/hist2.png ADDED
frontend/public/histo/hist3.jpg ADDED
frontend/public/manalife_LOGO.jpg ADDED
frontend/public/white_logo.png ADDED
frontend/src/App.tsx ADDED
@@ -0,0 +1,125 @@
1
+ import { useState, useEffect } from "react";
2
+ import axios from "axios";
3
+ import { Header } from "./components/Header";
4
+ import { Sidebar } from "./components/Sidebar";
5
+ import { UploadSection } from "./components/UploadSection";
6
+ import { ResultsPanel } from "./components/ResultsPanel";
7
+ import { Footer } from "./components/Footer";
8
+ import { ProgressBar } from "./components/progressbar";
9
+
10
+ export function App() {
11
+ // ----------------------------
12
+ // State Management
13
+ // ----------------------------
14
+ const [selectedTest, setSelectedTest] = useState("cytology");
15
+ const [uploadedImage, setUploadedImage] = useState<string | null>(null);
16
+ const [selectedModel, setSelectedModel] = useState("");
17
+ const [apiResult, setApiResult] = useState<any>(null);
18
+ const [showResults, setShowResults] = useState(false);
19
+ const [currentStep, setCurrentStep] = useState(0);
20
+ const [loading, setLoading] = useState(false);
21
+
22
+ // ----------------------------
23
+ // Progress bar logic
24
+ // ----------------------------
25
+ useEffect(() => {
26
+ if (showResults) setCurrentStep(2);
27
+ else if (uploadedImage) setCurrentStep(1);
28
+ else setCurrentStep(0);
29
+ }, [uploadedImage, showResults]);
30
+
31
+ // ----------------------------
32
+ // Reset logic — new test
33
+ // ----------------------------
34
+ useEffect(() => {
35
+ setCurrentStep(0);
36
+ setShowResults(false);
37
+ setUploadedImage(null);
38
+ setSelectedModel("");
39
+ setApiResult(null);
40
+ }, [selectedTest]);
41
+
42
+ // ----------------------------
43
+ // Analyze handler (Backend call)
44
+ // ----------------------------
45
+ const handleAnalyze = async () => {
46
+ if (!uploadedImage || !selectedModel) {
47
+ alert("Please select a model and upload an image first!");
48
+ return;
49
+ }
50
+
51
+
52
+ setLoading(true);
53
+ setShowResults(false);
54
+ setApiResult(null);
55
+
56
+ try {
57
+ // Convert Base64 → File
58
+ const blob = await fetch(uploadedImage).then((r) => r.blob());
59
+ const file = new File([blob], "input.jpg", { type: blob.type });
60
+
61
+ const formData = new FormData();
62
+ formData.append("file", file);
63
+ formData.append("analysis_type", selectedTest);
64
+ formData.append("model_name", selectedModel);
65
+
66
+ // POST to backend
67
+ const baseURL =
68
+ import.meta.env.MODE === "development"
69
+ ? "http://127.0.0.1:8000"
70
+ : window.location.origin;
71
+
72
+ const res = await axios.post(`${baseURL}/predict/`, formData, {
73
+ headers: { "Content-Type": "multipart/form-data" },
74
+ });
75
+
76
+ setApiResult(res.data);
77
+ setShowResults(true);
78
+ } catch (err) {
79
+ console.error("❌ Error during inference:", err);
80
+ alert("Error analyzing the image. Check backend logs.");
81
+ } finally {
82
+ setLoading(false);
83
+ }
84
+ };
85
+ // ----------------------------
86
+ // Layout
87
+ // ----------------------------
88
+ return ( <div className="flex flex-col min-h-screen w-full bg-gray-50"> <Header /> <ProgressBar currentStep={currentStep} />
89
+
90
+
91
+ <div className="flex flex-1">
92
+ <Sidebar selectedTest={selectedTest} onTestChange={setSelectedTest} />
93
+
94
+ <main className="flex-1 p-6">
95
+ <div className="max-w-7xl mx-auto grid grid-cols-1 lg:grid-cols-2 gap-6">
96
+ {/* Upload & Model Selection */}
97
+ <UploadSection
98
+ selectedTest={selectedTest}
99
+ uploadedImage={uploadedImage}
100
+ setUploadedImage={setUploadedImage}
101
+ selectedModel={selectedModel}
102
+ setSelectedModel={setSelectedModel}
103
+ onAnalyze={handleAnalyze}
104
+ />
105
+
106
+ {/* Results Panel */}
107
+ {showResults && (
108
+ <ResultsPanel
109
+ uploadedImage={
110
+ apiResult?.annotated_image_url || uploadedImage
111
+ }
112
+ result={apiResult}
113
+ loading={loading}
114
+ />
115
+ )}
116
+ </div>
117
+ </main>
118
+ </div>
119
+
120
+ <Footer />
121
+ </div>
122
+
123
+
124
+ );
125
+ }
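handleAnalyze above re-fetches the base64 data URL to obtain a Blob, wraps it in a File, and posts multipart form data to the backend's /predict/ route with axios. For reference, the same request can be expressed with the built-in fetch API; this is a hedged sketch that reuses the field names from the component (file, analysis_type, model_name) and assumes the backend is reachable at the same development base URL as above.

    // Sketch of the /predict/ call without axios. The route and form fields mirror App.tsx;
    // the response shape is whatever backend/app.py returns (typed as unknown here).
    async function predict(dataUrl: string, analysisType: string, modelName: string): Promise<unknown> {
      const blob = await fetch(dataUrl).then((r) => r.blob());

      const formData = new FormData();
      formData.append("file", new File([blob], "input.jpg", { type: blob.type }));
      formData.append("analysis_type", analysisType);
      formData.append("model_name", modelName);

      const res = await fetch("http://127.0.0.1:8000/predict/", { method: "POST", body: formData });
      if (!res.ok) throw new Error(`Prediction request failed with status ${res.status}`);
      return res.json();
    }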
frontend/src/AppRouter.tsx ADDED
@@ -0,0 +1,10 @@
1
+ import React from "react";
2
+ import { BrowserRouter, Routes, Route } from "react-router-dom";
3
+ import { App } from "./App";
4
+ export function AppRouter() {
5
+ return <BrowserRouter>
6
+ <Routes>
7
+ <Route path="/" element={<App />} />
8
+ </Routes>
9
+ </BrowserRouter>;
10
+ }
frontend/src/components/Footer.tsx ADDED
@@ -0,0 +1,50 @@
1
+ import React from 'react';
2
+
3
+ export function Footer() {
4
+ return (
5
+ <footer
6
+ className="relative w-full text-white py-10 mt-auto bg-cover bg-center"
7
+ style={{
8
+ backgroundImage:
9
+ "url('banner.jpeg')",
10
+ }}
11
+ >
12
+ {/* Overlay for readability */}
13
+ <div className="absolute inset-0 bg-gradient-to-t from-slate-900/95 via-slate-900/70 to-transparent" />
14
+
15
+ {/* Main footer content */}
16
+ <div className="relative max-w-7xl mx-auto px-8">
17
+ <div className="flex justify-center gap-8 mb-4">
18
+ <a
19
+ href="#"
20
+ className="relative text-blue-300 hover:text-blue-100 transition-all duration-300 after:content-[''] after:absolute after:left-0 after:bottom-0 after:w-0 hover:after:w-full after:h-[1px] after:bg-blue-300 after:transition-all after:duration-300"
21
+ >
22
+ Help Center
23
+ </a>
24
+ <a
25
+ href="#"
26
+ className="relative text-blue-300 hover:text-blue-100 transition-all duration-300 after:content-[''] after:absolute after:left-0 after:bottom-0 after:w-0 hover:after:w-full after:h-[1px] after:bg-blue-300 after:transition-all after:duration-300"
27
+ >
28
+ Contact Support
29
+ </a>
30
+ </div>
31
+
32
+ <div className="text-center">
33
+ <p className="font-semibold mb-2">© 2025 Manalife. All rights reserved.</p>
34
+ <p className="text-gray-300 text-sm">
35
+ Advancing innovation in women's health and digital pathology.
36
+ </p>
37
+ </div>
38
+ </div>
39
+
40
+ {/* Logo at bottom-right corner */}
41
+ <div className="absolute bottom-4 right-8">
42
+ <img
43
+ src="/white_logo.png"
44
+ alt="Manalife Logo"
45
+ className="h-12 w-auto opacity-90 hover:opacity-100 transition-opacity duration-300"
46
+ />
47
+ </div>
48
+ </footer>
49
+ );
50
+ }
frontend/src/components/Header.tsx ADDED
@@ -0,0 +1,40 @@
1
+ export function Header() {
2
+ return (
3
+ <div className="w-full bg-white">
4
+ {/* Banner */}
5
+ <div
6
+ className="w-full h-24 bg-cover bg-center"
7
+ style={{
8
+ backgroundImage:
9
+ "url('/banner.jpeg')",
10
+ }}
11
+ >
12
+ <div className="w-full h-full bg-gradient-to-r from-blue-900/80 to-teal-900/80 flex items-center px-8">
13
+ {/* Logo + Title */}
14
+ <div className="flex items-center gap-4">
15
+ <img
16
+ src="/white_logo.png"
17
+ alt="Manalife Logo"
18
+ className="h-16 w-auto"
19
+ />
20
+ <h1 className="text-white text-2xl font-semibold tracking-wide">
21
+ Manalife AI Pathology Assistant
22
+ </h1>
23
+ </div>
24
+ </div>
25
+ </div>
26
+
27
+ {/* Disclaimer */}
28
+ <div className="bg-blue-50 border-b border-blue-200 px-8 py-3">
29
+ <h3 className="font-semibold text-blue-900 mb-1">Public Disclaimer</h3>
30
+ <p className="text-sm text-blue-800">
31
+ Manalife AI models are research prototypes developed to advance
32
+ innovation in women's health and digital pathology. They are not
33
+ certified medical devices and are not intended for direct diagnosis or
34
+ treatment. Clinical validation and regulatory approval are required
35
+ before any medical use.
36
+ </p>
37
+ </div>
38
+ </div>
39
+ );
40
+ }
frontend/src/components/ResultsPanel.tsx ADDED
@@ -0,0 +1,143 @@
1
+ import { DownloadIcon, InfoIcon, Loader2Icon } from "lucide-react";
2
+
3
+ interface ResultsPanelProps {
4
+ uploadedImage: string | null;
5
+ result?: any;
6
+ loading?: boolean;
7
+ }
8
+
9
+ export function ResultsPanel({ uploadedImage, result, loading }: ResultsPanelProps) {
10
+ if (loading) {
11
+ return ( <div className="bg-white rounded-lg shadow-sm p-6 flex flex-col items-center justify-center"> <Loader2Icon className="w-10 h-10 text-blue-600 animate-spin mb-3" /> <p className="text-gray-600 font-medium">Analyzing image...</p> </div>
12
+ );
13
+ }
14
+
15
+ if (!result) {
16
+ return ( <div className="bg-white rounded-lg shadow-sm p-6 text-center text-gray-500">
17
+ No analysis result available yet. </div>
18
+ );
19
+ }
20
+
21
+ const {
22
+ prediction,
23
+ confidence,
24
+ probabilities,
25
+ detections,
26
+ summary,
27
+ annotated_image_url,
28
+ model_name,
29
+ analysis_type,
30
+ } = result;
31
+
32
+ const handleDownload = () => {
33
+ if (annotated_image_url) {
34
+ const link = document.createElement("a");
35
+ link.href = annotated_image_url;
36
+ link.download = "analysis_result.jpg";
37
+ link.click();
38
+ }
39
+ };
40
+
41
+ return ( <div className="bg-white rounded-lg shadow-sm p-6">
42
+ {/* Header */} <div className="flex items-center justify-between mb-6"> <div> <h2 className="text-2xl font-bold text-gray-800">
43
+ {model_name ? model_name.toUpperCase() : "Analysis Result"} </h2> <p className="text-sm text-gray-500 capitalize">
44
+ {analysis_type || "Test Type"} </p> </div>
45
+ {annotated_image_url && ( <button
46
+ onClick={handleDownload}
47
+ className="flex items-center gap-2 bg-green-600 text-white px-4 py-2 rounded-lg hover:bg-green-700 transition-colors"
48
+ > <DownloadIcon className="w-4 h-4" />
49
+ Download Image </button>
50
+ )} </div>
51
+
52
+
53
+ {/* Image */}
54
+ <div className="relative mb-6 rounded-lg overflow-hidden border border-gray-200">
55
+ <img
56
+ src={annotated_image_url || uploadedImage || "/ui.jpg"}
57
+ alt="Analysis Result"
58
+ className="w-full h-64 object-cover"
59
+ />
60
+ </div>
61
+
62
+ {/* Results Summary */}
63
+ <div className="mb-6">
64
+ {prediction && (
65
+ <h3
66
+ className={`text-3xl font-bold ${
67
+ prediction.toLowerCase().includes("malignant") ||
68
+ prediction.toLowerCase().includes("abnormal")
69
+ ? "text-red-600"
70
+ : "text-green-600"
71
+ }`}
72
+ >
73
+ {prediction}
74
+ </h3>
75
+ )}
76
+
77
+ {confidence && (
78
+ <div className="mt-2">
79
+ <div className="flex items-center justify-between mb-1">
80
+ <span className="font-semibold text-gray-900">
81
+ Confidence: {(confidence * 100).toFixed(2)}%
82
+ </span>
83
+ <InfoIcon className="w-4 h-4 text-gray-400" />
84
+ </div>
85
+ <div className="w-full h-3 bg-gray-200 rounded-full overflow-hidden">
86
+ <div
87
+ className={`h-full ${
88
+ confidence > 0.7
89
+ ? "bg-green-500"
90
+ : confidence > 0.4
91
+ ? "bg-yellow-500"
92
+ : "bg-red-500"
93
+ }`}
94
+ style={{ width: `${confidence * 100}%` }}
95
+ />
96
+ </div>
97
+ </div>
98
+ )}
99
+
100
+ {summary && (
101
+ <p className="mt-4 text-gray-700 text-sm leading-relaxed">
102
+ {summary}
103
+ </p>
104
+ )}
105
+ </div>
106
+
107
+ {/* Detections / Probabilities */}
108
+ {detections && detections.length > 0 && (
109
+ <div className="mb-6">
110
+ <h4 className="font-semibold text-gray-900 mb-3">
111
+ Detected Regions:
112
+ </h4>
113
+ <ul className="text-sm text-gray-700 list-disc list-inside space-y-1">
114
+ {detections.map((det: any, i: number) => (
115
+ <li key={i}>
116
+ {det.name || "object"} – {(det.confidence * 100).toFixed(1)}%
117
+ </li>
118
+ ))}
119
+ </ul>
120
+ </div>
121
+ )}
122
+
123
+ {probabilities && (
124
+ <div className="mb-6">
125
+ <h4 className="font-semibold text-gray-900 mb-3">
126
+ Class Probabilities:
127
+ </h4>
128
+ <pre className="bg-gray-100 rounded-lg p-3 text-sm">
129
+ {JSON.stringify(probabilities, null, 2)}
130
+ </pre>
131
+ </div>
132
+ )}
133
+
134
+ {/* Report Button */}
135
+ <button className="w-full bg-blue-600 text-white py-3 rounded-lg font-medium hover:bg-blue-700 transition-colors flex items-center justify-center gap-2">
136
+ <DownloadIcon className="w-5 h-5" />
137
+ Generate Report
138
+ </button>
139
+ </div>
140
+
141
+
142
+ );
143
+ }
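ResultsPanel types its result prop as any but then destructures a fixed set of fields. The backend's actual response model is not visible in this part of the diff, so the interface below is only an inferred sketch of that contract, built from the fields the component reads; the value types are assumptions.

    // Assumed shape of the /predict/ response, inferred from the fields ResultsPanel destructures.
    // The authoritative definition lives in backend/app.py, which is not shown in this section.
    interface Detection {
      name?: string;
      confidence: number; // 0 to 1, rendered as a percentage in the panel
    }

    interface PredictionResult {
      prediction?: string;
      confidence?: number; // 0 to 1
      probabilities?: Record<string, number>;
      detections?: Detection[];
      summary?: string;
      annotated_image_url?: string;
      model_name?: string;
      analysis_type?: string;
    }

Typing the prop as PredictionResult instead of any would let the compiler catch mismatches if the backend response ever changes.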
frontend/src/components/Sidebar.tsx ADDED
@@ -0,0 +1,54 @@
1
+ import React, { useState } from 'react';
2
+ import { ChevronDownIcon, FileTextIcon, HelpCircleIcon } from 'lucide-react';
3
+ interface SidebarProps {
4
+ selectedTest: string;
5
+ onTestChange: (test: string) => void;
6
+ }
7
+ export function Sidebar({
8
+ selectedTest,
9
+ onTestChange
10
+ }: SidebarProps) {
11
+ const [isDropdownOpen, setIsDropdownOpen] = useState(false);
12
+ const testTypes = [{
13
+ value: 'cytology',
14
+ label: 'Cytology Analysis'
15
+ }, {
16
+ value: 'colposcopy',
17
+ label: 'Colposcopy Analysis'
18
+ }, {
19
+ value: 'histopathology',
20
+ label: 'Histopathology Analysis'
21
+ }];
22
+ return <aside className="w-64 bg-white border-r border-gray-200 p-4">
23
+ <div className="space-y-2">
24
+ {/* New Test Dropdown */}
25
+ <div className="relative">
26
+ <button onClick={() => setIsDropdownOpen(!isDropdownOpen)} className="w-full bg-blue-600 text-white rounded-lg px-4 py-3 flex items-center justify-between hover:bg-blue-700 transition-colors">
27
+ <div className="flex items-center gap-2">
28
+ <div className="w-2 h-2 bg-white rounded-full" />
29
+ <span className="font-medium">New Test</span>
30
+ </div>
31
+ <ChevronDownIcon className={`w-5 h-5 transition-transform ${isDropdownOpen ? 'rotate-180' : ''}`} />
32
+ </button>
33
+ {isDropdownOpen && <div className="absolute top-full left-0 right-0 mt-2 bg-white border border-gray-200 rounded-lg shadow-lg z-10">
34
+ {testTypes.map(test => <button key={test.value} onClick={() => {
35
+ onTestChange(test.value);
36
+ setIsDropdownOpen(false);
37
+ }} className={`w-full text-left px-4 py-3 hover:bg-gray-50 transition-colors ${selectedTest === test.value ? 'bg-blue-50 text-blue-600' : 'text-gray-700'}`}>
38
+ {test.label}
39
+ </button>)}
40
+ </div>}
41
+ </div>
42
+ {/* History */}
43
+ <button className="w-full flex items-center gap-3 px-4 py-3 text-gray-700 hover:bg-gray-50 rounded-lg transition-colors">
44
+ <FileTextIcon className="w-5 h-5" />
45
+ <span>History</span>
46
+ </button>
47
+ {/* Help */}
48
+ <button className="w-full flex items-center gap-3 px-4 py-3 text-gray-700 hover:bg-gray-50 rounded-lg transition-colors">
49
+ <HelpCircleIcon className="w-5 h-5" />
50
+ <span>Help</span>
51
+ </button>
52
+ </div>
53
+ </aside>;
54
+ }
frontend/src/components/UploadSection.tsx ADDED
@@ -0,0 +1,194 @@
1
+ import React, { useRef } from 'react';
2
+ import { UploadIcon } from 'lucide-react';
3
+
4
+ interface UploadSectionProps {
5
+ selectedTest: string;
6
+ uploadedImage: string | null;
7
+ setUploadedImage: (image: string | null) => void;
8
+ selectedModel: string;
9
+ setSelectedModel: (model: string) => void;
10
+ onAnalyze: () => void;
11
+ }
12
+
13
+
14
+
15
+ export function UploadSection({
16
+ selectedTest,
17
+ uploadedImage,
18
+ setUploadedImage,
19
+ selectedModel,
20
+ setSelectedModel,
21
+ onAnalyze,
22
+ }: UploadSectionProps) {
23
+ const fileInputRef = useRef<HTMLInputElement>(null);
24
+
25
+ const modelOptions = {
26
+ cytology: [
27
+ { value: 'mwt', label: 'MWT' },
28
+ { value: 'yolo', label: 'YOLOv8' },
29
+ ],
30
+ colposcopy: [
31
+ { value: 'cin', label: 'Logistic-Colpo' },
32
+
33
+ ],
34
+ histopathology: [
35
+ { value: 'histopathology', label: 'Path Foundation Model' },
36
+
37
+ ],
38
+ };
39
+
40
+
41
+ const sampleImages = {
42
+ cytology: [
43
+ "/cyto/cyt1.jpg",
44
+ "/cyto/cyt2.png",
45
+ "/cyto/cyt3.png",
46
+ ],
47
+ colposcopy: [
48
+ "/colpo/colp1.jpg",
49
+ "/colpo/colp2.jpg",
50
+ "/colpo/colp3.jpg",
51
+ ],
52
+ histopathology: [
53
+ "/histo/hist1.png",
54
+ "/histo/hist2.png",
55
+ "/histo/hist3.jpg",
56
+ ],
57
+ };
58
+
59
+ const currentModels =
60
+ modelOptions[selectedTest as keyof typeof modelOptions] || [];
61
+
62
+ const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
63
+ const file = e.target.files?.[0];
64
+ if (file) {
65
+ const reader = new FileReader();
66
+ reader.onload = (event) => {
67
+ setUploadedImage(event.target?.result as string);
68
+ };
69
+ reader.readAsDataURL(file);
70
+ }
71
+ };
72
+
73
+ const handleDrop = (e: React.DragEvent) => {
74
+ e.preventDefault();
75
+ const file = e.dataTransfer.files[0];
76
+ if (file) {
77
+ const reader = new FileReader();
78
+ reader.onload = (event) => {
79
+ setUploadedImage(event.target?.result as string);
80
+ };
81
+ reader.readAsDataURL(file);
82
+ }
83
+ };
84
+
85
+ // Handle click on sample image
86
+ const handleSampleClick = (imgUrl: string) => {
87
+ setUploadedImage(imgUrl);
88
+ };
89
+
90
+ return (
91
+ <div className="bg-white rounded-lg shadow-sm p-6">
92
+ <h2 className="text-2xl font-semibold text-gray-900 mb-6">
93
+ Upload an image of a tissue sample
94
+ </h2>
95
+
96
+ {/* Upload Area */}
97
+ <div
98
+ onDrop={handleDrop}
99
+ onDragOver={(e) => e.preventDefault()}
100
+ onClick={() => fileInputRef.current?.click()}
101
+ className="border-2 border-dashed border-gray-300 rounded-lg p-8 text-center cursor-pointer hover:border-blue-400 transition-colors"
102
+ >
103
+ <input
104
+ ref={fileInputRef}
105
+ type="file"
106
+ accept="image/*"
107
+ onChange={handleFileChange}
108
+ className="hidden"
109
+ />
110
+ <div className="flex flex-col items-center">
111
+ <div className="w-16 h-16 bg-gray-100 rounded-full flex items-center justify-center mb-4">
112
+ <UploadIcon className="w-8 h-8 text-gray-400" />
113
+ </div>
114
+ {uploadedImage ? (
115
+ <>
116
+ <p className="text-green-600 font-medium mb-2">
117
+ Image uploaded successfully!
118
+ </p>
119
+ <p className="text-sm text-gray-500 mb-4">
120
+ Click to upload a different image
121
+ </p>
122
+ <div className="w-32 h-32 rounded-lg overflow-hidden border border-gray-200">
123
+ <img
124
+ src={uploadedImage}
125
+ alt="Uploaded sample"
126
+ className="w-full h-full object-cover"
127
+ />
128
+ </div>
129
+ </>
130
+ ) : (
131
+ <p className="text-gray-600">Drag and drop or click to upload</p>
132
+ )}
133
+ </div>
134
+ </div>
135
+
136
+ {/* Model Selection */}
137
+ <div className="mt-6">
138
+ <label className="block text-sm font-medium text-gray-700 mb-2">
139
+ Select Analysis Model:
140
+ </label>
141
+ <select
142
+ value={selectedModel}
143
+ onChange={(e) => setSelectedModel(e.target.value)}
144
+ className="w-full px-4 py-3 border border-gray-300 rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent"
145
+ >
146
+ <option value="">Choose a model...</option>
147
+ {currentModels.map((model) => (
148
+ <option key={model.value} value={model.value}>
149
+ {model.label}
150
+ </option>
151
+ ))}
152
+ </select>
153
+ </div>
154
+
155
+ {/* Analyze Button */}
156
+ <button
157
+ onClick={onAnalyze}
158
+ disabled={!uploadedImage || !selectedModel}
159
+ className="w-full mt-6 bg-blue-600 text-white py-3 rounded-lg font-medium hover:bg-blue-700 disabled:bg-gray-300 disabled:cursor-not-allowed transition-colors"
160
+ >
161
+ Analyze
162
+ </button>
163
+
164
+ {/* Separator */}
165
+ <hr className="my-8 border-gray-200" />
166
+
167
+ {/* Sample Images Section */}
168
+ <div>
169
+ <h3 className="text-lg font-semibold text-gray-800 mb-4">
170
+ Sample Images
171
+ </h3>
172
+ <div className="flex flex-wrap gap-4">
173
+ {(sampleImages[selectedTest as keyof typeof sampleImages] || []).map(
174
+ (img, index) => (
175
+ <div
176
+ key={index}
177
+ className={`w-20 h-20 rounded-lg border-2 cursor-pointer transition-transform hover:scale-105 hover:border-blue-500 overflow-hidden ${
178
+ uploadedImage === img ? 'border-blue-600' : 'border-gray-300'
179
+ }`}
180
+ onClick={() => handleSampleClick(img)}
181
+ >
182
+ <img
183
+ src={img}
184
+ alt={`Sample ${index + 1}`}
185
+ className="w-full h-full object-cover"
186
+ />
187
+ </div>
188
+ )
189
+ )}
190
+ </div>
191
+ </div>
192
+ </div>
193
+ );
194
+ }
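handleFileChange and handleDrop above duplicate the FileReader logic. Purely as a sketch (this helper does not exist in the repository), the shared part could be factored into a small Promise-based utility:

    // Hypothetical helper: read a File into a base64 data URL, as both handlers do inline above.
    function fileToDataUrl(file: File): Promise<string> {
      return new Promise((resolve, reject) => {
        const reader = new FileReader();
        reader.onload = () => resolve(reader.result as string);
        reader.onerror = () => reject(reader.error);
        reader.readAsDataURL(file);
      });
    }

    // Usage inside either handler:
    //   fileToDataUrl(file).then(setUploadedImage);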
frontend/src/components/progressbar.tsx ADDED
@@ -0,0 +1,61 @@
1
+ import { Fragment } from "react";
2
+ import { CheckIcon, FileTextIcon } from "lucide-react";
3
+
4
+ interface ProgressBarProps {
5
+ currentStep: number;
6
+ }
7
+
8
+ export function ProgressBar({ currentStep }: ProgressBarProps) {
9
+ const steps = [
10
+ { label: "Upload", index: 0 },
11
+ { label: "Analyze", index: 1 },
12
+ { label: "Report", index: 2 },
13
+ ];
14
+
15
+ return (
16
+ <div className="bg-white px-8 py-6 border-b border-gray-200">
17
+ <div className="max-w-2xl mx-auto flex items-center justify-center">
18
+ {steps.map((step, index) => (
19
+ <Fragment key={step.label}>
20
+ <div className="flex flex-col items-center">
21
+ {/* Step circle */}
22
+ <div
23
+ className={`w-12 h-12 rounded-full flex items-center justify-center text-white font-semibold transition-all duration-300 ${
24
+ currentStep > step.index
25
+ ? "bg-gradient-to-r from-blue-800 to-teal-600"
26
+ : currentStep === step.index
27
+ ? "bg-gradient-to-r from-blue-600 to-teal-500"
28
+ : "bg-gray-300 text-gray-600"
29
+ }`}
30
+ >
31
+ {currentStep > step.index ? (
32
+ <CheckIcon className="w-6 h-6" />
33
+ ) : step.index === 2 ? (
34
+ <FileTextIcon className="w-6 h-6" />
35
+ ) : (
36
+ <span>{index + 1}</span>
37
+ )}
38
+ </div>
39
+
40
+ {/* Label */}
41
+ <span className="mt-2 text-sm font-medium text-gray-700">
42
+ {step.label}
43
+ </span>
44
+ </div>
45
+
46
+ {/* Connecting line */}
47
+ {index < steps.length - 1 && (
48
+ <div
49
+ className={`h-1 w-32 mx-4 rounded-full transition-all duration-300 ${
50
+ currentStep > step.index
51
+ ? "bg-gradient-to-r from-blue-800 to-teal-600"
52
+ : "bg-gray-300"
53
+ }`}
54
+ />
55
+ )}
56
+ </Fragment>
57
+ ))}
58
+ </div>
59
+ </div>
60
+ );
61
+ }
frontend/src/index.css ADDED
@@ -0,0 +1,5 @@
1
+ /* PLEASE NOTE: THESE TAILWIND IMPORTS SHOULD NEVER BE DELETED */
2
+ @import 'tailwindcss/base';
3
+ @import 'tailwindcss/components';
4
+ @import 'tailwindcss/utilities';
5
+ /* DO NOT DELETE THESE TAILWIND IMPORTS, OTHERWISE THE STYLING WILL NOT RENDER AT ALL */
frontend/src/index.tsx ADDED
@@ -0,0 +1,5 @@
1
+ import './index.css';
2
+ import React from "react";
3
+ import { render } from "react-dom";
4
+ import { App } from "./App";
5
+ render(<App />, document.getElementById("root"));
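Two things stand out in this entry point: it uses the legacy render API from react-dom even though package.json pins React 18.3 (where render is deprecated in favour of createRoot), and it mounts App directly, so the AppRouter defined earlier is never used. A React 18 style entry point might look like the sketch below; treating AppRouter as the intended root is an assumption, not something the commit itself does.

    // Sketch of a React 18 entry point (src/index.tsx); not the committed code.
    import './index.css';
    import { createRoot } from 'react-dom/client';
    import { AppRouter } from './AppRouter';

    createRoot(document.getElementById('root')!).render(<AppRouter />);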