Spaces:
Running
on
T4
Running
on
T4
| # -*- coding: utf-8 -*- | |
| # | |
| # @File: app.py | |
| # @Author: Haozhe Xie | |
| # @Date: 2024-03-02 16:30:00 | |
| # @Last Modified by: Haozhe Xie | |
| # @Last Modified at: 2024-10-13 15:36:50 | |
| # @Email: [email protected] | |
| import gradio as gr | |
| import logging | |
| import numpy as np | |
| import os | |
| import pickle | |
| import ssl | |
| import subprocess | |
| import sys | |
| import urllib.request | |
| from PIL import Image | |
| # Reinstall PyTorch with CUDA 11.8 (Default version is 12.1) | |
| # subprocess.call( | |
| # [ | |
| # "pip", | |
| # "install", | |
| # "torch==2.2.2", | |
| # "torchvision==0.17.2", | |
| # "--index-url", | |
| # "https://download.pytorch.org/whl/cu118", | |
| # ] | |
| # ) | |
| import torch | |
# On Hugging Face ZeroGPU hardware use the real ``spaces`` package; elsewhere
# provide a stand-in whose ``GPU`` decorator is a transparent pass-through, so
# both ``@spaces.GPU`` and ``@spaces.GPU(duration=...)`` work unchanged.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Minimal stand-in for the ``spaces`` package outside ZeroGPU."""

        @staticmethod
        def GPU(func=None, **kwargs):
            """Return ``func`` wrapped in a no-op decorator.

            Args:
                func: The function to wrap. When omitted (parameterized form,
                    e.g. ``@spaces.GPU(duration=60)``) a decorator is returned.
                **kwargs: Accepted for API compatibility and ignored.

            Returns:
                A wrapper that simply calls ``func``, carrying ``func``'s
                name, docstring and attributes (via ``functools.wraps``).
            """
            if func is None:
                # Parameterized form: hand back the real decorator.
                return lambda f: spaces.GPU(f)

            @functools.wraps(func)
            def wrapper(*args, **kw):
                return func(*args, **kw)

            return wrapper
# Work around "ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED]
# certificate verify failed" during downloads.
# NOTE(review): this swaps the default HTTPS context factory for the
# unverified one, so certificate verification is disabled for every HTTPS
# request in this process that uses the default context, not only the model
# downloads.
setattr(ssl, "_create_default_https_context", ssl._create_unverified_context)

# Put the bundled ``gaussiancity`` directory (next to this file) on the module
# search path so the modules inside it are importable by their top-level names.
sys.path.extend([os.path.join(os.path.dirname(__file__), "gaussiancity")])
| def _get_output(cmd): | |
| try: | |
| return subprocess.check_output(cmd).decode("utf-8") | |
| except Exception as ex: | |
| logging.exception(ex) | |
| return None | |
def _prepend_env_path(name, entry):
    """Prepend ``entry`` to the ``os.pathsep``-separated environment variable ``name``.

    Unlike ``"%s:%s" % (entry, old)``, this never leaves an empty element
    behind when the variable is unset or empty. An empty element in ``PATH``
    (and likewise in ``LD_LIBRARY_PATH`` for glibc's loader) means "the
    current directory", which is not what is intended here.
    """
    current = os.environ.get(name, "")
    os.environ[name] = os.pathsep.join([entry, current]) if current else entry


def install_cuda_toolkit():
    """Download and silently install the CUDA 12.2 toolkit, then export its env.

    Best-effort: the download, ``chmod`` and installer commands are run with
    ``subprocess.call``, and non-zero exit statuses are logged rather than
    raised. Afterwards ``CUDA_HOME``, ``PATH``, ``LD_LIBRARY_PATH`` and
    ``TORCH_CUDA_ARCH_LIST`` are set in ``os.environ`` for this process.
    """
    # CUDA 11.8 alternative:
    # https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    for cmd in (
        ["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE],
        ["chmod", "+x", CUDA_TOOLKIT_FILE],
        [CUDA_TOOLKIT_FILE, "--silent", "--toolkit"],
    ):
        ret = subprocess.call(cmd)
        if ret != 0:
            logging.warning("Command %s exited with status %d" % (cmd, ret))

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    _prepend_env_path("PATH", "%s/bin" % os.environ["CUDA_HOME"])
    _prepend_env_path("LD_LIBRARY_PATH", "%s/lib" % os.environ["CUDA_HOME"])
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    # NOTE(review): 8.0/8.6 are Ampere compute capabilities; a T4 is 7.5, so
    # extensions compiled with this list may not target T4 hardware — confirm.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
def setup_runtime_env():
    """Log toolchain/GPU diagnostics and install the bundled CUDA extensions.

    Every entry of the ``wheels`` directory next to this file is passed to
    ``pip install`` for the running interpreter (pre-compiled extensions, see
    https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/110).
    A missing ``wheels`` directory and failing pip runs are logged as warnings
    rather than raised, so the app still attempts to start.
    """
    logging.info("Python Version: %s" % _get_output([sys.executable, "--version"]))
    logging.info("CUDA Version: %s" % _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s" % _get_output(["gcc", "--version"]))
    cuda_available = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_available)
    # get_device_capability() raises when no CUDA device is usable.
    if cuda_available:
        logging.info(
            "CUDA Device Capability: %s" % (torch.cuda.get_device_capability(),)
        )

    # Install pre-compiled CUDA extensions. (Alternatively, the sources under
    # gaussiancity/extensions/* can be built with `pip install .` per directory.)
    ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
    if not os.path.isdir(ext_dir):
        logging.warning(
            "Extension directory %s not found; skipping installation." % ext_dir
        )
    else:
        # Sorted for a deterministic install order (os.listdir order is arbitrary).
        for e in sorted(os.listdir(ext_dir)):
            logging.info("Installing Extensions from %s" % e)
            # `python -m pip` guarantees the packages land in this interpreter.
            ret = subprocess.call(
                [sys.executable, "-m", "pip", "install", os.path.join(ext_dir, e)],
                stderr=subprocess.STDOUT,
            )
            if ret != 0:
                logging.warning("pip install of %s exited with status %d" % (e, ret))

    logging.info(
        "Installed Python Packages: %s"
        % _get_output([sys.executable, "-m", "pip", "list"])
    )
def get_models(file_name):
    """Load a pretrained GaussianCity generator checkpoint.

    If ``file_name`` does not exist locally, it is first downloaded from the
    ``hzxie/gaussian-city`` Hugging Face repository. The checkpoint's ``cfg``
    configures a ``gaussiancity.generator.Generator``, whose weights are then
    loaded (non-strictly) from the ``"gaussian_g"`` entry.

    Args:
        file_name: Checkpoint file name, used both as the local path and as
            the path inside the Hugging Face repository.

    Returns:
        The generator in evaluation mode. When CUDA is available, it is
        wrapped in ``torch.nn.DataParallel`` and moved to the GPU.
    """
    import gaussiancity.generator

    if not os.path.exists(file_name):
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated checkpoint that the
        # existence check above would accept on the next start.
        tmp_name = file_name + ".part"
        try:
            urllib.request.urlretrieve(
                "https://huggingface.co/hzxie/gaussian-city/resolve/main/%s" % file_name,
                tmp_name,
            )
            os.replace(tmp_name, file_name)
        finally:
            if os.path.exists(tmp_name):
                os.remove(tmp_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): weights_only=False unpickles arbitrary objects (needed for
    # ckpt["cfg"]); only load checkpoints from a trusted source.
    ckpt = torch.load(file_name, map_location=torch.device(device), weights_only=False)
    model = gaussiancity.generator.Generator(
        ckpt["cfg"].NETWORK.GAUSSIAN,
        n_classes=ckpt["cfg"].DATASETS.GOOGLE_EARTH.N_CLASSES,
        proj_size=ckpt["cfg"].DATASETS.GOOGLE_EARTH.PROJ_SIZE,
    )
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
    # Inference mode on every device, not only on the CUDA path.
    model.eval()

    result = model.load_state_dict(ckpt["gaussian_g"], strict=False)
    # strict=False silently skips mismatches; surface them. NOTE(review): if the
    # checkpoint keys carry a DataParallel "module." prefix, the unwrapped CPU
    # model would match none of them — check this warning on CPU.
    if result.missing_keys or result.unexpected_keys:
        logging.warning(
            "%s: %d missing / %d unexpected keys while loading weights"
            % (file_name, len(result.missing_keys), len(result.unexpected_keys))
        )
    return model
def _load_or_build_pickle(path, build):
    """Return the object pickled at ``path``, building and caching it if absent.

    On a cache miss, ``build()`` is called. Its result is written to a
    temporary file that is atomically renamed onto ``path``, so an
    interrupted write can never leave a truncated cache behind.
    """
    if os.path.exists(path):
        # NOTE: pickle.load can execute arbitrary code; these caches are meant
        # to be produced locally by this function, never supplied externally.
        with open(path, "rb") as fp:
            return pickle.load(fp)
    obj = build()
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "wb") as fp:
            pickle.dump(obj, fp)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return obj


def _build_city_layout():
    """Build the NYC layout dict from the height-field and segmentation PNGs."""
    import gaussiancity.inference

    with Image.open("assets/NYC-HghtFld.png") as im:
        td_hf = np.array(im).astype(np.int32)
    # Fix: nonzero is not supported for tensors with more than INT_MAX elements
    # (capping heights at 500 presumably bounds a downstream height-based
    # volume below that limit — confirm).
    td_hf[td_hf > 500] = 500
    bu_hf = np.zeros_like(td_hf)
    with Image.open("assets/NYC-SegMap.png") as im:
        seg_map = np.array(im.convert("P")).astype(np.int32)
    # A copy is passed, presumably because get_instance_seg_map modifies its
    # input in place — confirm.
    ins_map = gaussiancity.inference.get_instance_seg_map(seg_map.copy())
    pts_map = gaussiancity.inference.get_point_map(seg_map)
    return {
        "TD_HF": td_hf,
        "BU_HF": bu_hf,
        "SEG": seg_map,
        "INS": ins_map,
        "PTS": pts_map,
    }


def get_city_layout():
    """Return the New York city layout, computing and caching it on first use.

    The layout (keys ``TD_HF``, ``BU_HF``, ``SEG``, ``INS``, ``PTS``) is cached
    in ``assets/NYC.pkl``. The instance centres computed from ``INS`` and
    ``TD_HF`` are cached separately in ``assets/CENTERS.pkl`` and added under
    ``CTR``. Paths are relative to the current working directory.

    Returns:
        The layout dict, including ``CTR``.
    """
    import gaussiancity.inference

    layout = _load_or_build_pickle("assets/NYC.pkl", _build_city_layout)
    layout["CTR"] = _load_or_build_pickle(
        "assets/CENTERS.pkl",
        lambda: gaussiancity.inference.get_centers(layout["INS"], layout["TD_HF"]),
    )
    return layout
@spaces.GPU  # Requests a GPU on ZeroGPU hardware; a no-op pass-through elsewhere.
def get_generated_city(radius, altitude, azimuth, map_center):
    """Render a city view with the preloaded models and layout (Gradio callback).

    The foreground/background generators and the city layout are read from
    the ``fgm``, ``bgm`` and ``city_layout`` attributes, which the entry point
    attaches to this function before the app starts.

    Args:
        radius: Camera radius, in metres per the UI slider label.
        altitude: Camera altitude, in metres per the UI slider label.
        azimuth: Camera azimuth, in degrees per the UI slider label.
        map_center: Map centre in pixels. It is passed as both centre
            arguments of ``generate_city`` (presumably x and y).

    Returns:
        Whatever ``gaussiancity.inference.generate_city`` returns; the UI
        displays it as a numpy image.
    """
    logging.info("CUDA is available: %s" % torch.cuda.is_available())
    logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
    # The import must be done after CUDA extension compilation
    import gaussiancity.inference

    return gaussiancity.inference.generate_city(
        get_generated_city.fgm.to("cuda"),
        get_generated_city.bgm.to("cuda"),
        get_generated_city.city_layout,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )
def _strip_front_matter(text):
    """Return ``text`` without a leading ``---``-delimited front-matter block."""
    if not text.startswith("---"):
        return text
    end = text.find("\n---", 3)
    if end == -1:
        return text
    return text[end + len("\n---") :]


def main(debug):
    """Build and launch the Gradio demo.

    The description is README.md with its leading front-matter block removed
    (e.g. Hugging Face Space metadata). The article shown below the demo is
    ARTICLE.md.

    Args:
        debug: Forwarded to ``app.launch(debug=...)``.
    """
    title = "Generative Gaussian Splatting for Unbounded 3D City Generation"
    # Explicit UTF-8: the default encoding is locale-dependent.
    with open("README.md", "r", encoding="utf-8") as f:
        desc = _strip_front_matter(f.read())
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(256, 960, value=768, step=4, label="Camera Radius (m)"),
            gr.Slider(256, 960, value=768, step=4, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=210, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1024, 7168, value=3570, step=4, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        flagging_mode="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)
if __name__ == "__main__":
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # Log the environment for debugging, but mask values whose names suggest
    # credentials (e.g. HF_TOKEN) so secrets never reach the application logs.
    _SECRET_MARKERS = ("TOKEN", "SECRET", "PASSWORD", "KEY", "CREDENTIAL")
    _env_for_log = {
        k: ("***" if any(m in k.upper() for m in _SECRET_MARKERS) else v)
        for k, v in os.environ.items()
    }
    logging.info("Environment Variables: %s" % _env_for_log)

    # Installing the CUDA toolkit at start-up is disabled; call
    # install_cuda_toolkit() here if `nvcc --version` is unavailable.

    logging.info("Installing pre-compiled CUDA extensions...")
    setup_runtime_env()

    logging.info("Downloading pretrained models...")
    fgm = get_models("GaussianCity-Fgnd.pth")
    bgm = get_models("GaussianCity-Bgnd.pth")
    # The Gradio callback reads its models/layout from function attributes.
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    city_layout = get_city_layout()
    get_generated_city.city_layout = city_layout

    logging.info("Starting the main application...")
    main(os.getenv("DEBUG") == "1")