File size: 3,094 Bytes
370e492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Run inference on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

'''This script relies on pre-downloaded models. If you don't have them, run the following commands to download them:

!kaggle kernels output rajeev86/training-unet-for-image-denoising
!kaggle kernels output rajeev86/training-residual-unet-for-image-denoising
!kaggle kernels output rajeev86/training-unet-with-residuals-and-cbam-layers

Note that you may need Kaggle credentials to download the models successfully.
'''

# Locations of the three TorchScript checkpoints, relative to the working
# directory the script is launched from.
model1_path = 'models/Script_Unet.pt'
model2_path = 'models/Script_Res-Unet.pt'
model3_path = 'models/Script_Att-Unet.pt'


def _load_scripted_model(path):
    """Load a TorchScript model onto ``device`` in eval mode.

    Args:
        path: Filesystem path of the TorchScript file to load.

    Returns:
        The loaded module switched to evaluation mode, or ``None`` when
        loading fails for any reason (the error is printed, not raised).
    """
    try:
        model = torch.jit.load(path, map_location=device)
        model.eval()
    except Exception as e:
        # Best-effort on purpose: a missing or corrupt checkpoint should only
        # disable that one model, not abort the script or the other models.
        print(f"Error loading model: {e}")
        return None
    return model


# Each model is loaded independently so that one failure does not throw away
# the models that did load successfully.
model1 = _load_scripted_model(model1_path)
model2 = _load_scripted_model(model2_path)
model3 = _load_scripted_model(model3_path)

def _tensor_to_display_image(batch):
    """Convert a single-image batch tensor into a PIL image for plotting.

    Args:
        batch: Model output; assumed to hold one image with a leading batch
            dimension of size 1 (TODO confirm against the model's contract).

    Returns:
        A PIL image built from the first (only) element of ``batch``.
    """
    # ToPILImage scales float tensors by 255 and casts to uint8; values outside
    # [0, 1] would wrap around in that cast and appear as speckle artifacts,
    # so clamp before converting.
    return transforms.ToPILImage()(batch.squeeze(0).cpu().clamp(0, 1))


def _show_denoising_comparison(image_path):
    """Denoise one image with every available model and plot the results.

    Args:
        image_path: Path of the noisy image file to open.
    """
    # convert() returns an independent in-memory copy, so the underlying file
    # can be closed as soon as the conversion is done.
    with Image.open(image_path) as source:
        noisy_image = source.convert('RGB')

    # Add a batch dimension and move the input to the inference device.
    noisy_tensor = transforms.ToTensor()(noisy_image).unsqueeze(0).to(device)

    images_to_show = [noisy_image]
    titles = ['Noisy Image']

    with torch.no_grad():
        for title, model in (
            ('Unet model', model1),
            ('res Unet model', model2),
            ('Att model', model3),
        ):
            # Explicit None check: truthiness of a module object may dispatch
            # to __len__, which is not a reliable "was it loaded" signal.
            if model is None:
                continue
            images_to_show.append(_tensor_to_display_image(model(noisy_tensor)))
            titles.append(title)

    panel_count = len(images_to_show)
    # squeeze=False always yields a 2-D array of axes, so the single-panel
    # case (no model loaded) needs no special handling.
    _, axes = plt.subplots(
        1, panel_count, figsize=(5 * panel_count, 5), squeeze=False
    )
    for ax, img, title in zip(axes[0], images_to_show, titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def denoise_image_and_show(image_paths):
    """Show each noisy image side by side with every model's denoised output.

    One figure is displayed per input image. The first panel is the noisy
    input; one further panel is added for each of ``model1``, ``model2`` and
    ``model3`` that is not ``None``.

    Args:
        image_paths: Iterable of paths to image files. An empty iterable
            results in no figures being shown.

    Returns:
        None.
    """
    for image_path in image_paths:
        _show_denoising_comparison(image_path)

# Sample noisy images to run through the denoising comparison; paths are
# relative to the working directory the script is launched from.
image_list = [
    'images/145079.jpg',
    'images/258089.jpg',
    'images/29030.jpg',
    'images/228076.jpg',
]

# Display one comparison figure per sample image.
denoise_image_and_show(image_list)