Spaces:
Runtime error
Runtime error
kxhit
commited on
Commit
·
5833474
1
Parent(s):
ed19cf4
queue 1
Browse files- 6DoF/dataset.py +1 -29
- app.py +6 -14
6DoF/dataset.py
CHANGED
|
@@ -2,40 +2,12 @@ import os
|
|
| 2 |
import math
|
| 3 |
from pathlib import Path
|
| 4 |
import torch
|
| 5 |
-
import
|
| 6 |
-
from torch.utils.data import Dataset, DataLoader
|
| 7 |
-
from torchvision import transforms
|
| 8 |
from PIL import Image
|
| 9 |
import numpy as np
|
| 10 |
-
import webdataset as wds
|
| 11 |
-
from torch.utils.data.distributed import DistributedSampler
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
import sys
|
| 14 |
|
| 15 |
-
class ObjaverseDataLoader():
|
| 16 |
-
def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
|
| 17 |
-
self.root_dir = root_dir
|
| 18 |
-
self.batch_size = batch_size
|
| 19 |
-
self.num_workers = num_workers
|
| 20 |
-
self.total_view = total_view
|
| 21 |
-
|
| 22 |
-
image_transforms = [torchvision.transforms.Resize((256, 256)),
|
| 23 |
-
transforms.ToTensor(),
|
| 24 |
-
transforms.Normalize([0.5], [0.5])]
|
| 25 |
-
self.image_transforms = torchvision.transforms.Compose(image_transforms)
|
| 26 |
-
|
| 27 |
-
def train_dataloader(self):
|
| 28 |
-
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
|
| 29 |
-
image_transforms=self.image_transforms)
|
| 30 |
-
# sampler = DistributedSampler(dataset)
|
| 31 |
-
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
| 32 |
-
# sampler=sampler)
|
| 33 |
-
|
| 34 |
-
def val_dataloader(self):
|
| 35 |
-
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
|
| 36 |
-
image_transforms=self.image_transforms)
|
| 37 |
-
sampler = DistributedSampler(dataset)
|
| 38 |
-
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
| 39 |
|
| 40 |
def get_pose(transformation):
|
| 41 |
# transformation: 4x4
|
|
|
|
| 2 |
import math
|
| 3 |
from pathlib import Path
|
| 4 |
import torch
|
| 5 |
+
from torch.utils.data import Dataset
|
|
|
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
import numpy as np
|
|
|
|
|
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
import sys
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def get_pose(transformation):
|
| 13 |
# transformation: 4x4
|
app.py
CHANGED
|
@@ -183,19 +183,11 @@ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_
|
|
| 183 |
# run inference
|
| 184 |
# pipeline.to(device)
|
| 185 |
pipeline.enable_xformers_memory_efficient_attention()
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
| 192 |
-
output_type="numpy").images
|
| 193 |
-
elif CaPE_TYPE == "4DoF":
|
| 194 |
-
with torch.autocast("cuda"):
|
| 195 |
-
image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
|
| 196 |
-
height=h, width=w, T_in=T_in, T_out=T_out,
|
| 197 |
-
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
| 198 |
-
output_type="numpy").images
|
| 199 |
|
| 200 |
# save output image
|
| 201 |
output_dir = os.path.join(tmpdirname, "eschernet")
|
|
@@ -748,7 +740,7 @@ with gr.Blocks() as demo:
|
|
| 748 |
|
| 749 |
# demo.queue(max_size=10)
|
| 750 |
# demo.launch(share=True, server_name="0.0.0.0", server_port=None)
|
| 751 |
-
demo.queue(max_size=
|
| 752 |
|
| 753 |
# if __name__ == '__main__':
|
| 754 |
# main()
|
|
|
|
| 183 |
# run inference
|
| 184 |
# pipeline.to(device)
|
| 185 |
pipeline.enable_xformers_memory_efficient_attention()
|
| 186 |
+
image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
|
| 187 |
+
poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
|
| 188 |
+
height=h, width=w, T_in=T_in, T_out=T_out,
|
| 189 |
+
guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
|
| 190 |
+
output_type="numpy").images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
# save output image
|
| 193 |
output_dir = os.path.join(tmpdirname, "eschernet")
|
|
|
|
| 740 |
|
| 741 |
# demo.queue(max_size=10)
|
| 742 |
# demo.launch(share=True, server_name="0.0.0.0", server_port=None)
|
| 743 |
+
demo.queue(max_size=10).launch()
|
| 744 |
|
| 745 |
# if __name__ == '__main__':
|
| 746 |
# main()
|