# File size: 1,406 bytes
# 3ec871d (short revision hash shown by the file viewer)
from typing import Any, List, Tuple

from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec, InputParam, PipelineState, OutputParam
class DummyCustomBlockSimple(ModularPipelineBlocks):
    """Minimal custom block that prefixes the incoming prompt.

    The block reads ``prompt`` from its block state and writes
    ``output_prompt`` (``"Modular diffusers + " + prompt``) back to it.
    When ``use_dummy_model_component`` is true it additionally declares a
    ``transformer`` component of type ``FluxTransformer2DModel``.
    """

    def __init__(self, use_dummy_model_component: bool = False) -> None:
        """Create the block.

        Args:
            use_dummy_model_component: When true, ``expected_components``
                lists a ``transformer`` component; otherwise it is empty.
        """
        # The flag is assigned before the base initializer runs, as in the
        # original code. NOTE(review): presumably so it already exists if the
        # base class reads `expected_components` during init -- confirm.
        self.use_dummy_model_component = use_dummy_model_component
        super().__init__()

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Return the component specs this block declares (possibly none)."""
        if not self.use_dummy_model_component:
            return []
        return [ComponentSpec("transformer", FluxTransformer2DModel)]

    @property
    def inputs(self) -> List[InputParam]:
        """Return the block inputs: a single required ``prompt`` string."""
        return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """Return the intermediate inputs; this block declares none."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Return the intermediate outputs: the ``output_prompt`` string."""
        return [
            OutputParam(
                "output_prompt",
                type_hint=str,
                description="Modified prompt",
            )
        ]

    def __call__(self, components: Any, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Prefix the prompt and store the result in the pipeline state.

        Args:
            components: Passed through to the caller untouched.
            state: Pipeline state from which the block state is obtained and
                to which it is written back.

        Returns:
            The ``(components, state)`` pair. The annotation now reflects the
            tuple that is actually returned rather than a bare
            ``PipelineState``.
        """
        block_state = self.get_block_state(state)
        block_state.output_prompt = "Modular diffusers + " + block_state.prompt
        self.set_block_state(state, block_state)
        return components, state