edits
app.py
CHANGED
@@ -24,7 +24,7 @@ scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
 
 device = 'cpu'
 
-# Load
+# Load nets
 state = torch.load('fire.pth', map_location='cpu')
 state['net_params']['pretrained'] = None # no need for imagenet pretrained model
 net_sfm = fire_network.init_network(**state['net_params']).to(device)
@@ -37,8 +37,7 @@ for name, param in net_sfm.named_parameters():
 
 state2 = torch.load('fire_imagenet.pth', map_location='cpu')
 state2['net_params'] = state['net_params']
-state2['state_dict']
-# state2['net_params'] =
+state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict);
 net_imagenet = fire_network.init_network(**state['net_params']).to(device)
 net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
 
@@ -51,21 +50,16 @@ transform = transforms.Compose([
 # ---------------------------------------
 
 # class ImgDataset(data.Dataset):
-
 #     def __init__(self, images, imsize):
 #         self.images = images
 #         self.imsize = imsize
 #         self.transform = transforms.Compose([transforms.ToTensor(), \
 #             transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
-
-
 #     def __getitem__(self, index):
 #         img = self.images[index]
 #         img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS)
 #         print('after imresize:', img.size)
 #         return self.transform(img)
-
-
 #     def __len__(self):
 #         return len(self.images)
 
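The substantive change is in the second hunk: the bare `state2['state_dict']` expression (a no-op) is replaced by a merge that folds `dim_red_params_dict` into the ImageNet checkpoint's state dict before `net_imagenet.load_state_dict()` is called. `dim_red_params_dict` is presumably assembled earlier in app.py, plausibly from the loop over `net_sfm.named_parameters()` shown in the hunk context, so the ImageNet checkpoint gains the dimensionality-reduction weights it is missing. A minimal sketch of the `dict(a, **b)` merge idiom used on that line, with stand-in tensors and an assumed key name:

# Illustration only: stand-in tensors and an assumed key name; the real
# dim_red_params_dict is built elsewhere in app.py.
import torch

imagenet_state_dict = {'features.0.weight': torch.zeros(3, 3)}
dim_red_params_dict = {'dim_red.weight': torch.zeros(2, 3)}  # assumed key

# dict(a, **b) returns a new dict with a's entries plus b's (b wins on
# duplicate keys); neither input dict is modified in place.
merged = dict(imagenet_state_dict, **dim_red_params_dict)
assert set(merged) == {'features.0.weight', 'dim_red.weight'}

With the missing entries merged in, `load_state_dict` can run with its default `strict=True`, which would explain why the commented-out `strict=False` fallback stays disabled.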