import numpy as np
import cv2
import torch
import scipy.sparse as sp
import sys
import os
from zipfile import ZipFile

from .plotting import plot_side_by_side_comparison

# Make the directory two levels above this file importable so that the
# `models` package can be found regardless of the working directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.HybridGNet2IGSC import Hybrid

# Lazily-initialised model shared by predict_landmarks() and segment().
hybrid = None

# Number of contour landmarks per organ. Landmark arrays are laid out as
# [right lung | left lung | heart], in this order.
RLUNG_NODES = 44
LLUNG_NODES = 50
HEART_NODES = 26
_RL_END = RLUNG_NODES                # end index (exclusive) of right lung
_LL_END = RLUNG_NODES + LLUNG_NODES  # end index (exclusive) of left lung


def scipy_to_torch_sparse(scp_matrix):
    """Convert a SciPy sparse matrix into a torch sparse COO tensor.

    Args:
        scp_matrix: Any SciPy sparse matrix. It is converted to COO format
            first (a no-op for matrices that are already COO).

    Returns:
        A ``torch.float32`` sparse COO tensor with the same shape and entries.
    """
    coo = scp_matrix.tocoo()
    indices = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.tensor(coo.data, dtype=torch.float32)
    # torch.sparse.FloatTensor(...) is deprecated; sparse_coo_tensor is the
    # supported constructor for the same indices/values/shape triple.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))


## Adjacency Matrix
def mOrgan(N):
    """Return the N x N adjacency matrix of a closed contour (ring graph).

    Node ``i`` is connected to ``i - 1`` and ``i + 1``, with indices
    wrapping modulo N so the first and last nodes are neighbours.
    """
    sub = np.zeros([N, N])
    for i in range(N):
        sub[i, (i - 1) % N] = 1
        sub[i, (i + 1) % N] = 1
    return sub


## Downsampling Matrix
def mOrganD(N):
    """Return the ceil(N/2) x N pooling matrix for a contour of N nodes.

    Output node ``i`` averages input nodes ``2i`` and ``2i + 1``. When N is
    odd, the last output node copies the last input node unchanged.
    """
    N2 = int(np.ceil(N / 2))
    sub = np.zeros([N2, N])
    for i in range(N2):
        if (2 * i + 1) == N:
            sub[i, 2 * i] = 1
        else:
            sub[i, 2 * i] = 1 / 2
            sub[i, 2 * i + 1] = 1 / 2
    return sub


## Upsampling Matrix
def mOrganU(N):
    """Return the N x ceil(N/2) unpooling matrix for a contour of N nodes.

    Even fine nodes copy coarse node ``i // 2``. Odd fine nodes average
    coarse node ``i // 2`` and the next coarse node, wrapping around the
    contour.
    """
    N2 = int(np.ceil(N / 2))
    sub = np.zeros([N, N2])
    for i in range(N):
        if i % 2 == 0:
            sub[i, i // 2] = 1
        else:
            sub[i, i // 2] = 1 / 2
            sub[i, (i // 2 + 1) % N2] = 1 / 2
    return sub


def _block_diag(blocks):
    """Place (possibly rectangular) 2-D arrays along the diagonal of a zero matrix."""
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    out = np.zeros([rows, cols])
    r = c = 0
    for b in blocks:
        out[r:r + b.shape[0], c:c + b.shape[1]] = b
        r += b.shape[0]
        c += b.shape[1]
    return out


def genMatrixesLungsHeart():
    """Build the graph matrices for the right lung, left lung and heart contours.

    Each organ is an independent ring graph, so every matrix is
    block-diagonal with one block per organ (in right lung, left lung,
    heart order).

    Returns:
        A:  (120, 120) adjacency at full resolution (44 + 50 + 26 nodes).
        AD: (60, 60) adjacency at half resolution (22 + 25 + 13 nodes).
        D:  (60, 120) downsampling matrix (full -> half resolution).
        U:  (120, 60) upsampling matrix (half -> full resolution).
    """
    sizes = (RLUNG_NODES, LLUNG_NODES, HEART_NODES)
    half_sizes = [int(np.ceil(n / 2)) for n in sizes]

    A = _block_diag([mOrgan(n) for n in sizes])
    AD = _block_diag([mOrgan(n) for n in half_sizes])
    D = _block_diag([mOrganD(n) for n in sizes])
    U = _block_diag([mOrganU(n) for n in sizes])
    return A, AD, D, U


def zip_files(files, output_name="complete_results.zip"):
    """Write ``files`` into a new zip archive, storing each under its base name.

    Args:
        files: Iterable of file paths to add.
        output_name: Path of the archive to create (overwritten if present).

    Returns:
        ``output_name``.
    """
    with ZipFile(output_name, "w") as zip_obj:
        for path in files:
            zip_obj.write(path, arcname=os.path.basename(path))
    return output_name


def getMasks(landmarks, h, w):
    """Rasterise the three organ contours into filled binary masks.

    Args:
        landmarks: Array of (x, y) landmarks laid out as
            [right lung (44) | left lung (50) | heart (26)].
        h, w: Height and width of the masks to produce.

    Returns:
        Tuple ``(RL_mask, LL_mask, H_mask)`` of ``uint8`` arrays of shape
        (h, w). Pixels inside each contour are 255 and all others are 0.
    """
    organs = (landmarks[:_RL_END], landmarks[_RL_END:_LL_END], landmarks[_LL_END:])
    masks = []
    for contour in organs:
        mask = np.zeros([h, w], dtype='uint8')
        # thickness=-1 fills the polygon interior.
        mask = cv2.drawContours(mask, [contour.reshape(-1, 1, 2).astype('int')], -1, 255, -1)
        masks.append(mask)
    return tuple(masks)


def pad_to_square(img):
    """Zero-pad a 2-D image symmetrically so that it becomes square.

    An odd amount of padding puts the extra pixel on the bottom/right side.

    Returns:
        ``(padded_img, (padh, padw, auxh, auxw))``. Here padh/padw are the
        total rows/columns added, and auxh/auxw are the odd remainders.
    """
    h, w = img.shape[:2]
    if h > w:
        padw = h - w
        auxw = padw % 2
        img = np.pad(img, ((0, 0), (padw // 2, padw // 2 + auxw)), 'constant')
        return img, (0, padw, 0, auxw)
    padh = w - h
    auxh = padh % 2
    img = np.pad(img, ((padh // 2, padh // 2 + auxh), (0, 0)), 'constant')
    return img, (padh, 0, auxh, 0)


def preprocess(img):
    """Pad ``img`` to a square and resize it to 1024 x 1024 with cubic interpolation.

    Returns:
        ``(img_1024, (h, w, padding))``. Here ``h`` and ``w`` are the size of
        the padded square *before* resizing, and ``padding`` is the tuple
        returned by :func:`pad_to_square`.
    """
    img, padding = pad_to_square(img)
    h, w = img.shape[:2]
    if h != 1024 or w != 1024:
        img = cv2.resize(img, (1024, 1024), interpolation=cv2.INTER_CUBIC)
    return img, (h, w, padding)


def removePreprocess(output, info):
    """Map model landmark coordinates back to the original image frame.

    Args:
        output: Array whose last axis holds (x, y). This presumably holds
            coordinates normalised to the padded square image; TODO confirm
            against the model's decoder.
        info: The ``(h, w, padding)`` tuple produced by :func:`preprocess`.

    Returns:
        A new array with coordinates scaled to pixels and shifted by the
        leading (left/top) padding.
    """
    h, _, padding = info
    padh, padw, _, _ = padding
    # The padded image is square (h == w), so a single scale factor applies
    # to both axes; this also covers the h == 1024 case.
    output = output * h
    output[:, :, 0] -= padw // 2
    output[:, :, 1] -= padh // 2
    return output


def loadModel(device):
    """Build the HybridGNet model, load ``weights/weights.pt`` and set eval mode.

    Also stores the model in the module-level ``hybrid`` variable.

    Args:
        device: ``torch.device`` on which to place the model and graph matrices.

    Returns:
        The loaded ``Hybrid`` model.
    """
    global hybrid
    A, AD, D, U = genMatrixesLungsHeart()
    N1, N2 = A.shape[0], AD.shape[0]
    A, AD, D, U = [sp.csc_matrix(x).tocoo() for x in (A, AD, D, U)]

    D_, U_ = [D.copy()], [U.copy()]
    # One adjacency matrix per entry of config['n_nodes'].
    A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]

    config = {
        'n_nodes': [N1, N1, N1, N2, N2, N2],
        'latents': 64,
        'inputsize': 1024,
        'filters': [2, 32, 32, 32, 16, 16, 16],
        'skip_features': 32,
        'eval_sampling': True,
    }

    A_t = [scipy_to_torch_sparse(x).to(device) for x in A_]
    D_t = [scipy_to_torch_sparse(x).to(device) for x in D_]
    U_t = [scipy_to_torch_sparse(x).to(device) for x in U_]

    hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
    hybrid.load_state_dict(torch.load("weights/weights.pt", map_location=device))
    hybrid.eval()
    return hybrid


def _ensure_model():
    """Load the global model on first use (CUDA if available, else CPU) and return it."""
    global hybrid
    if hybrid is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        hybrid = loadModel(device)
    return hybrid


def predict_landmarks(img, n_samples=100):
    """Predict organ landmarks by sampling the model's latent space.

    The image is encoded once, ``n_samples`` latent vectors are drawn with
    ``hybrid.sampling``, and all of them are decoded in one batch.

    Args:
        img: 2-D image array (see :func:`preprocess`).
        n_samples: Number of latent samples to draw; must be >= 1.

    Returns:
        ``(means, stds)``: per-landmark mean and standard deviation over the
        samples, each of shape (num_landmarks, 2), in original-image pixels.

    Raises:
        ValueError: If ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    model = _ensure_model()

    img_proc, (h, w, padding) = preprocess(img)
    device = next(model.parameters()).device
    # Add batch and channel dimensions: (1, 1, 1024, 1024).
    data = torch.from_numpy(img_proc).unsqueeze(0).unsqueeze(0).to(device).float()

    with torch.no_grad():
        mu, log_var, conv6, conv5 = model.encode(data)
        z_exp = torch.stack([model.sampling(mu, log_var) for _ in range(n_samples)], dim=0)
        # Repeat the skip-connection features so they match the sample batch.
        conv6_exp = conv6.repeat(n_samples, 1, 1, 1)
        conv5_exp = conv5.repeat(n_samples, 1, 1, 1)
        output, _, _ = model.decode(z_exp, conv6_exp, conv5_exp)

    output = output.cpu().numpy().reshape(n_samples, -1, 2)
    # NOTE(review): samples are truncated to int before the statistics are
    # computed, which quantises both mean and std; kept for compatibility.
    output = removePreprocess(output, (h, w, padding)).astype('int')
    return np.mean(output, axis=0), np.std(output, axis=0)


def _corrupt_image(img, mask, noise_std):
    """Return a corrupted copy of ``img`` (float image in [0, 1]).

    Pixels where ``mask`` is bright are darkened: the image is combined with
    ``1 - mask/255`` by taking the minimum. If ``noise_std`` > 0, Gaussian
    noise is then added and the result is clipped to [0, 1].
    """
    if mask is not None:
        if mask.ndim == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
        mask = 1.0 - mask.astype(np.float32) / 255.0
        corrupted = np.minimum(img, mask)
    else:
        corrupted = img.copy()

    if noise_std > 0:
        noise = np.random.normal(0, noise_std, corrupted.shape)
        corrupted = np.clip(corrupted + noise, 0.0, 1.0)
    return corrupted


def _save_results(means, stds, h, w):
    """Write landmarks, masks and per-landmark stds under ``tmp/``.

    Returns:
        The nine written paths: landmarks, then masks, then stds, each in
        RL, LL, H order.
    """
    landmark_paths = ("tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt")
    mask_paths = ("tmp/RL_mask.png", "tmp/LL_mask.png", "tmp/H_mask.png")
    std_paths = ("tmp/RL_std.txt", "tmp/LL_std.txt", "tmp/H_std.txt")

    def split(arr):
        return arr[:_RL_END], arr[_RL_END:_LL_END], arr[_LL_END:]

    for path, pts in zip(landmark_paths, split(means)):
        np.savetxt(path, pts, delimiter=" ", fmt="%d")
    for path, mask in zip(mask_paths, getMasks(means, h, w)):
        cv2.imwrite(path, mask)
    for path, sd in zip(std_paths, split(stds)):
        np.savetxt(path, sd, delimiter=" ", fmt="%.4f")
    return [*landmark_paths, *mask_paths, *std_paths]


def segment(input_img, noise_std=0.0):
    """Segment an image and compare it against a corrupted version of itself.

    Args:
        input_img: Dict with key ``"image"`` (numpy array, values in 0-255)
            and optionally ``"mask"`` (colour or grayscale array, 0-255).
            Bright mask pixels are blacked out in the corrupted image.
        noise_std: Standard deviation of Gaussian noise added to the
            corrupted image (no noise when <= 0).

    Returns:
        ``(output_path, saved_files)``. ``output_path`` is the saved
        side-by-side comparison figure. ``saved_files`` lists the landmark,
        mask and std files for the original image, followed by the zip
        archive that contains them.
    """
    _ensure_model()

    # Original image and corrupted version
    img_orig = input_img["image"].astype(np.float32) / 255.0
    img_corr = _corrupt_image(img_orig, input_img.get("mask", None), noise_std)

    # Predict landmarks
    means_orig, stds_orig = predict_landmarks(img_orig)
    means_corr, stds_corr = predict_landmarks(img_corr)

    # Save landmarks, masks and stds for the original image
    os.makedirs("tmp", exist_ok=True)
    result_files = _save_results(means_orig, stds_orig, img_orig.shape[0], img_orig.shape[1])
    zipf = zip_files(result_files)

    # Side-by-side comparison figure
    import matplotlib.pyplot as plt
    fig = plot_side_by_side_comparison(img_orig, means_orig, stds_orig,
                                       img_corr, means_corr, stds_corr)
    output_path = "tmp/segmentation_comparison.png"
    try:
        fig.savefig(output_path, dpi=300)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    saved_files = result_files + [zipf]
    return output_path, saved_files