# Denoiser-Server / showcasing-few-results.py
# Rajeev-86
# new models added
# 370e492
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
# Run inference on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
'''This script relies on pre-downloaded models. If you don't have them, run the following commands to download them:
!kaggle kernels output rajeev86/training-unet-for-image-denoising
!kaggle kernels output rajeev86/training-residual-unet-for-image-denoising
!kaggle kernels output rajeev86/training-unet-with-residuals-and-cbam-layers
Note that you may need Kaggle credentials to download the models successfully.
'''
# Locations of the TorchScript archives for the three denoising models.
model1_path = 'models/Script_Unet.pt'
model2_path = 'models/Script_Res-Unet.pt'
model3_path = 'models/Script_Att-Unet.pt'


def _load_scripted_model(path):
    """Load one TorchScript model in eval mode, or return None on failure.

    Args:
        path: Filesystem path of the TorchScript archive to load.

    Returns:
        The loaded module, mapped onto the module-level ``device`` and
        switched to eval mode, or ``None`` if anything went wrong while
        loading (the error is printed rather than raised).
    """
    try:
        model = torch.jit.load(path, map_location=device)
        model.eval()
    except Exception as e:
        # Deliberately best-effort: the showcase still runs with whichever
        # models are available, so report the failure and carry on.
        print(f"Error loading model '{path}': {e}")
        return None
    return model


# Each model is loaded on its own so that one missing or corrupt archive
# does not throw away the models that did load successfully.
model1 = _load_scripted_model(model1_path)
model2 = _load_scripted_model(model2_path)
model3 = _load_scripted_model(model3_path)
def _to_display_image(output_tensor):
    """Convert a model output tensor into a PIL image for plotting.

    The leading dimension is squeezed away to mirror the ``unsqueeze(0)``
    applied to the model input, and the tensor is moved to the CPU.
    """
    image_tensor = output_tensor.squeeze(0).cpu()
    if image_tensor.is_floating_point():
        # ToPILImage scales float tensors by 255 and casts to uint8, so any
        # value outside [0, 1] would wrap around and show up as speckles.
        # NOTE(review): assumes the models are meant to emit [0, 1] images,
        # matching the ToTensor-scaled input -- confirm against training.
        image_tensor = image_tensor.clamp(0.0, 1.0)
    return transforms.ToPILImage()(image_tensor)


def _show_comparison(image_path):
    """Denoise one image with every loaded model and plot them side by side."""
    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image_path) as source:
        noisy_image = source.convert('RGB')
    noisy_tensor = transforms.ToTensor()(noisy_image).unsqueeze(0).to(device)

    images_to_show = [noisy_image]
    titles = ['Noisy Image']
    labelled_models = (
        ('Unet model', model1),
        ('res Unet model', model2),
        ('Att model', model3),
    )
    with torch.no_grad():
        for title, model in labelled_models:
            # A model is None when it could not be loaded; skip its panel.
            if model is None:
                continue
            images_to_show.append(_to_display_image(model(noisy_tensor)))
            titles.append(title)

    panel_count = len(images_to_show)
    # squeeze=False always yields a 2-D axes array, so the single-panel case
    # (no model loaded) needs no special handling.
    _, axes = plt.subplots(
        1, panel_count, figsize=(5 * panel_count, 5), squeeze=False
    )
    for ax, img, title in zip(axes[0], images_to_show, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def denoise_image_and_show(image_paths):
    """Show each noisy image next to the output of every loaded denoiser.

    One matplotlib figure is produced per image: the original on the left,
    followed by one panel for each of ``model1``, ``model2`` and ``model3``
    that is not ``None``.

    Args:
        image_paths: Iterable of image file paths to open and denoise.

    Returns:
        None. The results are displayed with ``plt.show()``.
    """
    for image_path in image_paths:
        _show_comparison(image_path)
# Sample noisy images used for the side-by-side comparison.
image_list = [
    'images/145079.jpg',
    'images/258089.jpg',
    'images/29030.jpg',
    'images/228076.jpg',
]

# Run the showcase only when executed as a script or notebook cell, so that
# importing this module does not open images or pop up figures.
if __name__ == "__main__":
    denoise_image_and_show(image_list)