Spaces:
Running
Running
| from .imagenhub_models import load_imagenhub_model | |
| from .playground_api import load_playground_model | |
# Model identifiers follow the "{source}_{name}_{type}" convention that
# load_pipeline expects: the source selects the backend loader and the
# type is either "generation" or "edition".

# Text-to-image generation models.
IMAGE_GENERATION_MODELS = [
    'imagenhub_LCM_generation',
    'imagenhub_SDXLTurbo_generation',
    'imagenhub_SDXL_generation',
    'imagenhub_PixArtAlpha_generation',
    'imagenhub_OpenJourney_generation',
    'imagenhub_SDXLLightning_generation',
    'imagenhub_StableCascade_generation',
    'playground_PlayGroundV2_generation',
    'playground_PlayGroundV2.5_generation',
]

# Image editing models.
IMAGE_EDITION_MODELS = [
    'imagenhub_CycleDiffusion_edition',
    'imagenhub_Pix2PixZero_edition',
    'imagenhub_Prompt2prompt_edition',
    'imagenhub_SDEdit_edition',
    'imagenhub_InstructPix2Pix_edition',
    'imagenhub_MagicBrush_edition',
    'imagenhub_PNP_edition',
]
def load_pipeline(model_name):
    """
    Load a model pipeline based on the model name.

    Args:
        model_name (str): The name of the model to load, of the form
            ``{source}_{name}_{type}``:
            - source: the backend, either ``imagenhub`` or ``playground``
            - name: the model name passed to the backend loader; everything
              between the first and the last underscore is treated as the
              name, so it may itself contain underscores
            - type: the type of the model, either ``generation`` or ``edition``

    Returns:
        The pipeline object returned by ``load_imagenhub_model`` or
        ``load_playground_model``.

    Raises:
        ValueError: If ``model_name`` is not of the three-part form above,
            or if its source is not supported.
    """
    # Split only on the first and last underscore: source and type are single
    # tokens, while the middle component (the model name) may contain "_".
    model_source, head_sep, remainder = model_name.partition("_")
    base_name, tail_sep, model_type = remainder.rpartition("_")
    if not (head_sep and tail_sep and model_source and base_name and model_type):
        raise ValueError(
            f"Model name {model_name!r} must be of the form "
            "{source}_{name}_{type}"
        )

    if model_source == "imagenhub":
        return load_imagenhub_model(base_name, model_type)
    if model_source == "playground":
        # The playground loader takes only the model name; the type is not forwarded.
        return load_playground_model(base_name)
    raise ValueError(f"Model source {model_source} not supported")