# upscale_board / constants.py
# Last commit by r3gm: "Update constants.py" (0c2a6c0, verified)
import re
from huggingface_hub import HfApi
from pprint import pprint
from stablepy import BUILTIN_UPSCALERS
# Upscaler choices for the GUI: display key -> model source.
# Built-in stablepy upscalers (every entry after the first two) map to their own
# name, with spaces in the key replaced by underscores; the entries added below
# map to downloadable weight-file URLs.
UPSCALER_DICT_GUI = {
    name.replace(" ", "_"): name
    for name in BUILTIN_UPSCALERS[2:]
}
UPSCALER_DICT_GUI.update({
    "RealESRNet_x4plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
    "4x_foolhardy_Remacri": "https://huggingface.co/FacehugmanIII/4x_foolhardy_Remacri/resolve/main/4x_foolhardy_Remacri.pth",
    "Remacri4xExtraSmoother": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/Remacri%204x%20ExtraSmoother.pth",
    "Lollypop": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/lollypop.pth",
    "NickelbackFS4x": "https://huggingface.co/hollowstrawberry/upscalers-backup/resolve/main/ESRGAN/NickelbackFS%204x.pth",
    "Real_HAT_GAN_SRx4": "https://huggingface.co/halffried/gyre_upscalers/resolve/main/hat_ganx4/Real_HAT_GAN_SRx4.safetensors",
    "HAT-L_SRx4_ImageNet-pretrain": "https://huggingface.co/halffried/gyre_upscalers/resolve/main/hat_lx4/HAT-L_SRx4_ImageNet-pretrain.safetensors",
    "Real-ESRGAN-Anime-finetuning": "https://huggingface.co/danhtran2mind/Real-ESRGAN-Anime-finetuning/resolve/main/Real-ESRGAN-Anime-finetuning.pth",
    "8x_NMKD-Superscale_150000_G": "https://huggingface.co/lantianhang/8x_NMKD-Superscale_150000_G/resolve/main/8x_NMKD-Superscale_150000_G.pth",
    "4x_Valar_v1": "https://huggingface.co/halffried/gyre_upscalers/resolve/main/esrgan_valar_x4/4x_Valar_v1.pth",
    "Ghibli_Grain": "https://huggingface.co/anonderpling/upscalers/resolve/main/ESRGAN/ghibli_grain.pth",
    "Detoon4x": "https://huggingface.co/anonderpling/upscalers/resolve/main/ESRGAN/4x_detoon_225k.pth",
    "2x_Text2HD_v1-RealPLKSR": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/2x_Text2HD_v.1-RealPLKSR.pth",
    "4xNomosWebPhoto_esrgan": "https://github.com/Phhofm/models/releases/download/4xNomosWebPhoto_esrgan/4xNomosWebPhoto_esrgan.pth",
    "4xTextureDAT2_otf": "https://github.com/Phhofm/models/releases/download/4xTextureDAT2_otf/4xTextureDAT2_otf.pth",
    "2xVHS2HD-RealPLKSR": "https://github.com/starinspace/StarinspaceUpscale/releases/download/Models/2xVHS2HD-RealPLKSR.pth",
})
def clean_filename_for_key(filename):
    """
    Turn a model filename into a dictionary-key-friendly name.

    Everything after the final "." is dropped (the whole string is kept when
    there is no "."), then each run of characters outside ``[a-zA-Z0-9_]`` is
    collapsed into a single underscore.

    Args:
        filename (str): A bare filename, e.g. ``"Remacri 4x ExtraSmoother.pth"``.

    Returns:
        str: The cleaned key, e.g. ``"Remacri_4x_ExtraSmoother"``.
    """
    stem, dot, tail = filename.rpartition(".")
    # Without a dot, rpartition leaves the entire input in `tail`.
    base = stem if dot else tail
    return re.sub(r"[^a-zA-Z0-9_]+", "_", base)
def add_upscalers_from_author(author_name, upscaler_dict, repos_to_avoid=None):
    """
    Scan an author's Hugging Face model repos and add their weight files to a dict.

    For every model repo owned by ``author_name`` whose id does not contain any
    entry of ``repos_to_avoid`` as a substring, the repo's ``.pth`` files are
    added; a repo's ``.safetensors`` files are added only when it has no
    ``.pth`` files. Each file is keyed by ``clean_filename_for_key`` of its
    basename and mapped to
    ``https://huggingface.co/<repo_id>/resolve/main/<percent-encoded path>``.
    Keys already present in ``upscaler_dict`` are never overwritten.

    Failures are reported with ``print`` rather than raised: if the author's
    repos cannot be listed the function returns without changes, and a failure
    on a single repo skips only that repo.

    Args:
        author_name (str): The Hugging Face username of the author.
        upscaler_dict (dict): The dictionary of upscalers to update in place.
        repos_to_avoid (list, optional): Substrings; any repo id containing one
            of them is skipped. Defaults to None (skip nothing).

    Returns:
        None. ``upscaler_dict`` is mutated in place.
    """
    # quote() percent-encodes spaces and reserved characters but keeps "/",
    # so sub-directory paths inside the repo are preserved.
    from urllib.parse import quote

    if repos_to_avoid is None:
        repos_to_avoid = []
    print(f"--- Processing author: {author_name} ---")
    api = HfApi()
    try:
        all_repo_ids = [
            # `id` is the documented ModelInfo field; `modelId` is kept as a
            # fallback for older huggingface_hub releases.
            getattr(model, "id", None) or model.modelId
            for model in api.list_models(author=author_name)
        ]
        # Filter out repositories whose id contains any avoided substring.
        repo_ids = [
            repo for repo in all_repo_ids
            if not any(avoid_name in repo for avoid_name in repos_to_avoid)
        ]
        filtered_count = len(all_repo_ids) - len(repo_ids)
        print(f"Found {len(all_repo_ids)} repositories. Skipping {filtered_count}, processing {len(repo_ids)}.")
    except Exception as e:
        # Best-effort: a Hub/network failure leaves upscaler_dict unchanged.
        print(f"Could not fetch repositories for {author_name}: {e}")
        return

    for repo_id in repo_ids:
        try:
            files = api.list_repo_files(repo_id)
            # Prefer .pth weights; fall back to .safetensors only when the
            # repo has no .pth files at all.
            model_files_to_add = [f for f in files if f.endswith(".pth")] or [
                f for f in files if f.endswith(".safetensors")
            ]
            if not model_files_to_add:
                continue
            print(f"Found {len(model_files_to_add)} model(s) in {repo_id}")
            for file_path in model_files_to_add:
                # Key on the bare filename, not the full in-repo path.
                filename = file_path.split("/")[-1]
                key = clean_filename_for_key(filename)
                if key in upscaler_dict:
                    # Never overwrite an existing entry (hard-coded or added earlier).
                    print(f" - Skipping duplicate key: {key}")
                    continue
                # Percent-encode the path: a name like "NickelbackFS 4x.pth" would
                # otherwise produce a URL with a raw space, which some HTTP
                # clients (e.g. urllib / http.client) reject.
                url = f"https://huggingface.co/{repo_id}/resolve/main/{quote(file_path)}"
                upscaler_dict[key] = url
                print(f" + Added: {key}")
        except Exception as e:
            # Best-effort per repo: report and continue with the next one.
            print(f"Could not process repo {repo_id}: {e}")
# --- Main Execution ---
# NOTE: these calls run at import time and query the Hugging Face Hub via
# add_upscalers_from_author.

# Repo-id substrings to skip when scanning Phips' repos (presumably repos that
# are not upscaler models — confirm when editing this list).
REPOS_TO_AVOID = [
    "BHI_Filtering_Post",
    "ppo-LunarLander-v2",
    "dqn-SpaceInvadersNoFrameskip-v4",
    "dqn-BeamRiderNoFrameskip-v4",
]

# Work on a shallow copy so UPSCALER_DICT_GUI keeps only the hard-coded
# entries while the Hub scans below extend updated_upscaler_dict.
updated_upscaler_dict = dict(UPSCALER_DICT_GUI)

# Kim2091: no repos are excluded.
add_upscalers_from_author("Kim2091", updated_upscaler_dict)
# Phips: exclude the repos listed above.
add_upscalers_from_author("Phips", updated_upscaler_dict, repos_to_avoid=REPOS_TO_AVOID)