import gradio as gr
import spaces
import torch
from gradio_rerun import Rerun
import rerun as rr
import rerun.blueprint as rrb
from pathlib import Path
import uuid

from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    is_symmetrized,
    interleave,
    transpose_to_landscape,
)
import os

from mini_dust3r.model import load_model
from catmlp_dpt_head import Cat_MLP_LocalFeatures_DPT_Pts3d, postprocess
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Original DUSt3R checkpoint load, kept for reference; the MASt3R checkpoint is loaded below.
# model = AsymmetricCroCo3DStereo.from_pretrained(
#     "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
# ).to(DEVICE)

from mini_dust3r.heads.linear_head import LinearPts3d
from mini_dust3r.heads.dpt_head import create_dpt_head
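
# head_factory maps a (head_type, output_mode) pair to the matching prediction-head module:
# the 'linear' and 'dpt' heads regress pts3d only, while 'catmlp+dpt' also outputs local
# feature descriptors whose dimension is parsed from the 'pts3d+desc<D>' output mode string.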
def head_factory(head_type, output_mode, net, has_conf=False):
    """Build a prediction head for the decoder."""
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    elif head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
        local_feat_dim = int(output_mode[10:])
        assert net.dec_depth > 9
        l2 = net.dec_depth
        feature_dim = 256
        last_dim = feature_dim // 2
        out_nchan = 3
        ed = net.enc_embed_dim
        dd = net.dec_embed_dim
        return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
                                               num_channels=out_nchan + has_conf,
                                               feature_dim=feature_dim,
                                               last_dim=last_dim,
                                               hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
                                               dim_tokens=[ed, dd, dd, dd],
                                               postprocess=postprocess,
                                               depth_mode=net.depth_mode,
                                               conf_mode=net.conf_mode,
                                               head_type='regression')
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
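
# AsymmetricMASt3R extends the DUSt3R backbone (AsymmetricCroCo3DStereo) with MASt3R's
# descriptor-aware heads: from_pretrained loads either a local checkpoint file or a
# Hugging Face Hub model id, and set_downstream_head builds the two per-view heads.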
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
    def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
        self.desc_mode = desc_mode
        self.two_confs = two_confs
        self.desc_conf_mode = desc_conf_mode
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device='cpu')
        else:
            return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)

    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
        assert img_size[0] % patch_size == 0 and img_size[
            1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        if self.desc_conf_mode is None:
            self.desc_conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
        # magic wrapper
        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)


model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to(DEVICE)
def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    # don't show 2D views if there are more than 4 images, so the layout stays uncluttered
    if len(image_name_list) > 4:
        blueprint = rrb.Blueprint(
            rrb.Horizontal(
                rrb.Spatial3DView(origin=f"{log_path}"),
            ),
            collapse_panels=True,
        )
    else:
        blueprint = rrb.Blueprint(
            rrb.Horizontal(
                contents=[
                    rrb.Spatial3DView(origin=f"{log_path}"),
                    rrb.Vertical(
                        contents=[
                            rrb.Spatial2DView(
                                origin=f"{log_path}/camera_{i}/pinhole/",
                                contents=[
                                    "+ $origin/**",
                                ],
                            )
                            for i in range(len(image_name_list))
                        ]
                    ),
                ],
                column_shares=[3, 1],
            ),
            collapse_panels=True,
        )
    return blueprint
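
# predict runs mini-dust3r inference on the uploaded image(s), logs the optimized result to
# Rerun, and returns the path of the saved .rrd recording, which the Rerun Gradio component
# then displays.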
@spaces.GPU  # request a GPU worker when running on ZeroGPU Spaces
def predict(image_name_list: list[str] | str):
    # the input must be a list of file paths or a single path string; otherwise raise an error
    if not isinstance(image_name_list, list) and not isinstance(image_name_list, str):
        raise gr.Error(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
        )
    uuid_str = str(uuid.uuid4())
    filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
    rr.init(f"{uuid_str}")
    log_path = Path("world")

    if isinstance(image_name_list, str):
        image_name_list = [image_name_list]

    optimized_results: OptimizedResult = inferece_dust3r(
        image_dir_or_list=image_name_list,
        model=model,
        device=DEVICE,
        batch_size=1,
    )

    blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
    rr.send_blueprint(blueprint)

    rr.set_time_sequence("sequence", 0)
    log_optimized_result(optimized_results, log_path)
    rr.save(filename.as_posix())
    return filename.as_posix()
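
# The UI below has two tabs (single image and multi image); both feed predict and display
# the resulting .rrd recording in an embedded Rerun viewer.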
with gr.Blocks(
    css=""".gradio-container {margin: 0 !important; min-width: 100%;}""",
    title="Mini-DUSt3R Demo",
) as demo:
    # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
    gr.HTML('<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>')
    gr.HTML(
        '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>'
    )
    gr.HTML(
        '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>'
    )
    with gr.Tab(label="Single Image"):
        with gr.Column():
            single_image = gr.Image(type="filepath", height=300)
            run_btn_single = gr.Button("Run")
            rerun_viewer_single = Rerun(height=900)
            run_btn_single.click(
                fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
            )

            example_single_dir = Path("examples/single_image")
            example_single_files = sorted(example_single_dir.glob("*.png"))
            examples_single = gr.Examples(
                examples=example_single_files,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
                fn=predict,
                cache_examples="lazy",
            )
    with gr.Tab(label="Multi Image"):
        with gr.Column():
            multi_files = gr.File(file_count="multiple")
            run_btn_multi = gr.Button("Run")
            rerun_viewer_multi = Rerun(height=900)
            run_btn_multi.click(
                fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
            )

demo.launch()
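
# When run locally (e.g. `python app.py`), Gradio serves the demo at http://localhost:7860 by default.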