Spaces:
Paused
Paused
| import os | |
| import requests | |
| from rich.console import Console | |
| from tqdm import tqdm | |
| import subprocess | |
| import sys | |
| try: | |
| import folder_paths | |
| except ModuleNotFoundError: | |
| import sys | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) | |
| import folder_paths | |
| models_to_download = { | |
| "DeepBump": { | |
| "size": 25.5, | |
| "download_url": "https://github.com/HugoTini/DeepBump/raw/master/deepbump256.onnx", | |
| "destination": "deepbump", | |
| }, | |
| "Face Swap": { | |
| "size": 660, | |
| "download_url": [ | |
| "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth", | |
| "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", | |
| "https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx", | |
| ], | |
| "destination": "insightface", | |
| }, | |
| "GFPGAN (face enhancement)": { | |
| "size": 332, | |
| "download_url": [ | |
| "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", | |
| "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" | |
| # TODO: provide a way to selectively download models from "packs" | |
| # https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth | |
| # https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth | |
| # https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth | |
| ], | |
| "destination": "face_restore", | |
| }, | |
| "FILM: Frame Interpolation for Large Motion": { | |
| "size": 402, | |
| "download_url": [ | |
| "https://drive.google.com/drive/folders/131_--QrieM4aQbbLWrUtbO2cGbX8-war" | |
| ], | |
| "destination": "FILM", | |
| }, | |
| } | |
| console = Console() | |
| from urllib.parse import urlparse | |
| from pathlib import Path | |
| def download_model(download_url, destination): | |
| if isinstance(download_url, list): | |
| for url in download_url: | |
| download_model(url, destination) | |
| return | |
| filename = os.path.basename(urlparse(download_url).path) | |
| response = None | |
| if "drive.google.com" in download_url: | |
| try: | |
| import gdown | |
| except ImportError: | |
| print("Installing gdown") | |
| subprocess.check_call( | |
| [ | |
| sys.executable, | |
| "-m", | |
| "pip", | |
| "install", | |
| "git+https://github.com/melMass/gdown@main", | |
| ] | |
| ) | |
| import gdown | |
| if "/folders/" in download_url: | |
| # download folder | |
| try: | |
| gdown.download_folder(download_url, output=destination, resume=True) | |
| except TypeError: | |
| gdown.download_folder(download_url, output=destination) | |
| return | |
| # download from google drive | |
| gdown.download(download_url, destination, quiet=False, resume=True) | |
| return | |
| response = requests.get(download_url, stream=True) | |
| total_size = int(response.headers.get("content-length", 0)) | |
| destination_path = os.path.join(destination, filename) | |
| with open(destination_path, "wb") as file: | |
| with tqdm( | |
| total=total_size, unit="B", unit_scale=True, desc=destination_path, ncols=80 | |
| ) as progress_bar: | |
| for data in response.iter_content(chunk_size=4096): | |
| file.write(data) | |
| progress_bar.update(len(data)) | |
| console.print( | |
| f"Downloaded model from {download_url} to {destination_path}", | |
| style="bold green", | |
| ) | |
| def ask_user_for_downloads(models_to_download): | |
| console.print("Choose models to download:") | |
| choices = {} | |
| for i, model_name in enumerate(models_to_download.keys(), start=1): | |
| choices[str(i)] = model_name | |
| console.print(f"{i}. {model_name}") | |
| console.print( | |
| "Enter the numbers of the models you want to download (comma-separated):" | |
| ) | |
| user_input = console.input(">> ") | |
| selected_models = user_input.split(",") | |
| models_to_download_selected = {} | |
| for choice in selected_models: | |
| choice = choice.strip() | |
| if choice in choices: | |
| model_name = choices[choice] | |
| models_to_download_selected[model_name] = models_to_download[model_name] | |
| elif choice == "": | |
| # download all | |
| models_to_download_selected = models_to_download | |
| else: | |
| console.print(f"Invalid choice: {choice}. Skipping.") | |
| return models_to_download_selected | |
| def handle_interrupt(): | |
| console.print("Interrupted by user.", style="bold red") | |
| def main(models_to_download, skip_input=False): | |
| try: | |
| models_to_download_selected = {} | |
| def check_destination(urls, destination): | |
| if isinstance(urls, list): | |
| for url in urls: | |
| check_destination(url, destination) | |
| return | |
| filename = os.path.basename(urlparse(urls).path) | |
| destination = os.path.join(folder_paths.models_dir, destination) | |
| if not os.path.exists(destination): | |
| os.makedirs(destination) | |
| destination_path = os.path.join(destination, filename) | |
| if os.path.exists(destination_path): | |
| url_name = os.path.basename(urlparse(urls).path) | |
| console.print( | |
| f"Checkpoint '{url_name}' for {model_name} already exists in '{destination}'" | |
| ) | |
| else: | |
| model_details["destination"] = destination | |
| models_to_download_selected[model_name] = model_details | |
| for model_name, model_details in models_to_download.items(): | |
| destination = model_details["destination"] | |
| download_url = model_details["download_url"] | |
| check_destination(download_url, destination) | |
| if not models_to_download_selected: | |
| console.print("No new models to download.") | |
| return | |
| models_to_download_selected = ( | |
| ask_user_for_downloads(models_to_download_selected) | |
| if not skip_input | |
| else models_to_download_selected | |
| ) | |
| for model_name, model_details in models_to_download_selected.items(): | |
| download_url = model_details["download_url"] | |
| destination = model_details["destination"] | |
| console.print(f"Downloading {model_name}...") | |
| download_model(download_url, destination) | |
| except KeyboardInterrupt: | |
| handle_interrupt() | |
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("-y", "--yes", action="store_true", help="skip user input") | |
| args = parser.parse_args() | |
| main(models_to_download, args.yes) | |