Spaces:
Runtime error
Runtime error
Fix metrics to work with grayscale datasets (#9)
Browse files- metrics/metric_utils.py +6 -1
metrics/metric_utils.py
CHANGED
|
@@ -213,6 +213,8 @@ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_l
|
|
| 213 |
# Main loop.
|
| 214 |
item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
|
| 215 |
for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
|
|
|
|
|
|
|
| 216 |
features = detector(images.to(opts.device), **detector_kwargs)
|
| 217 |
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
| 218 |
progress.update(stats.num_items)
|
|
@@ -262,7 +264,10 @@ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel
|
|
| 262 |
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
|
| 263 |
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
| 264 |
images.append(run_generator(z, c))
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
| 266 |
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
| 267 |
progress.update(stats.num_items)
|
| 268 |
return stats
|
|
|
|
| 213 |
# Main loop.
|
| 214 |
item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
|
| 215 |
for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
|
| 216 |
+
if images.shape[1] == 1:
|
| 217 |
+
images = images.repeat([1, 3, 1, 1])
|
| 218 |
features = detector(images.to(opts.device), **detector_kwargs)
|
| 219 |
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
| 220 |
progress.update(stats.num_items)
|
|
|
|
| 264 |
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
|
| 265 |
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
| 266 |
images.append(run_generator(z, c))
|
| 267 |
+
images = torch.cat(images)
|
| 268 |
+
if images.shape[1] == 1:
|
| 269 |
+
images = images.repeat([1, 3, 1, 1])
|
| 270 |
+
features = detector(images, **detector_kwargs)
|
| 271 |
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
| 272 |
progress.update(stats.num_items)
|
| 273 |
return stats
|