from typing import Dict, Optional, Tuple
from collections import OrderedDict
import random

import torch
from transformers import CLIPTextConfig
from diffusers import UNet2DConditionModel
from optimum.exporters.onnx.model_configs import VisionOnnxConfig
from optimum.exporters.openvino import main_export
from optimum.utils.input_generators import (
    DEFAULT_DUMMY_SHAPES,
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
    DummyVisionInputGenerator,
)
from optimum.utils.normalized_config import NormalizedConfig
# IMPORTANT: the downloaded model cache folder must contain a scheduler config,
# otherwise the export errors out.
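# A minimal sketch of one way to satisfy that requirement (hypothetical local path;
# assumes the checkpoint's own LCMScheduler config is the one you want):
#
#   from diffusers import LCMScheduler
#   scheduler = LCMScheduler.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="scheduler")
#   scheduler.save_pretrained("<local_model_cache_folder>/scheduler")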

class CustomDummyTimestepInputGenerator(DummyInputGenerator):
    """
    Generates dummy time step inputs.
    """

    SUPPORTED_INPUT_NAMES = (
        "timestep",
        "timestep_cond",
        "text_embeds",
        "time_ids",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        time_cond_proj_dim: int = 256,
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        self.task = task
        self.vocab_size = normalized_config.vocab_size
        self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
        self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
        if random_batch_size_range:
            low, high = random_batch_size_range
            self.batch_size = random.randint(low, high)
        else:
            self.batch_size = batch_size
        self.time_cond_proj_dim = normalized_config.get("time_cond_proj_dim", time_cond_proj_dim)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        shape = [self.batch_size]
        if input_name == "timestep":
            return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
        if input_name == "timestep_cond":
            shape.append(self.time_cond_proj_dim)
            return self.random_float_tensor(shape, min_value=-1.0, max_value=1.0, framework=framework, dtype=float_dtype)
        shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids)
        return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)


class LCMUNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
    # operator, available since opset 14.
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        CustomDummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = OrderedDict({
            "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
            "timestep": {0: "steps"},
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
            "timestep_cond": {0: "batch_size"},
        })
        # TODO: add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs["text_embeds"] = {0: "batch_size"}
            common_inputs["time_ids"] = {0: "batch_size"}
        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        }

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "sample": "out_sample",
        }

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }
        return dummy_inputs

    def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
        # The default ordered_inputs breaks the input order when timestep_cond is involved,
        # so simply return the original inputs.
        return self.inputs


model_id = "SimianLuo/LCM_Dreamshaper_v7"
text_encoder_config = CLIPTextConfig.from_pretrained(model_id, subfolder="text_encoder")
unet_config = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").config
unet_config.text_encoder_projection_dim = text_encoder_config.projection_dim
unet_config.requires_aesthetics_score = False

custom_onnx_configs = {
    "unet": LCMUNetOnnxConfig(config=unet_config, task="semantic-segmentation"),
}

main_export(
    model_name_or_path=model_id,
    output="./",
    task="stable-diffusion",
    fp16=False,
    int8=False,
    custom_onnx_configs=custom_onnx_configs,
)
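
# Once the export finishes, the resulting folder can presumably be loaded with
# optimum-intel's OVStableDiffusionPipeline -- an untested sketch, assuming the exported
# files land next to this script and that the standard SD pipeline is acceptable for
# a quick smoke test of the LCM UNet:
#
#   from optimum.intel import OVStableDiffusionPipeline
#   pipe = OVStableDiffusionPipeline.from_pretrained("./", compile=True)
#   image = pipe("a photo of an astronaut riding a horse",
#                num_inference_steps=4).images[0]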