Quickstart
Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are ModularPipelineBlocks that can be combined with other blocks to adapt to new workflows. The blocks are converted into a ModularPipeline, a user-friendly interface for developers.
This doc will show you how to implement a Differential Diffusion pipeline with the modular framework.
ModularPipelineBlocks
ModularPipelineBlocks are definitions that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.
- ModularPipelineBlocks is the most basic block for a single step.
- SequentialPipelineBlocks is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block.
- LoopSequentialPipelineBlocks is a multi-block that runs iteratively and is designed for iterative workflows.
- AutoPipelineBlocks is a collection of blocks for different workflows that selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline.
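To make the sequential and loop composition rules concrete, here is a toy model in plain Python, not the Modular Diffusers API: each "block" is a function that reads and writes a shared dict. The helpers run_sequential and run_loop, and the stand-in blocks, are hypothetical; they only mimic how a sequential multi-block feeds one block's outputs to the next and how a loop multi-block repeats its sub-blocks once per step.

```python
from typing import Callable, Dict, List

# A toy "block" takes the shared state dict and returns it updated.
Block = Callable[[Dict], Dict]
# A toy loop sub-block also receives the current step index.
LoopBlock = Callable[[Dict, int], Dict]


def run_sequential(blocks: List[Block], state: Dict) -> Dict:
    """Run blocks in order; whatever one block writes, the next can read."""
    for block in blocks:
        state = block(state)
    return state


def run_loop(sub_blocks: List[LoopBlock], state: Dict, num_steps: int) -> Dict:
    """Run every sub-block once per iteration, passing the step index."""
    for i in range(num_steps):
        for sub_block in sub_blocks:
            state = sub_block(state, i)
    return state


def encode(state: Dict) -> Dict:
    state["x"] = len(state["prompt"])  # stand-in for a text encoder
    return state


def before(state: Dict, i: int) -> Dict:
    state["trace"].append(f"before{i}")  # record that the pre-step ran
    return state


def denoise(state: Dict, i: int) -> Dict:
    state["x"] -= 1  # stand-in for one denoising update
    return state


def denoise_loop(state: Dict) -> Dict:
    state["trace"] = []
    return run_loop([before, denoise], state, num_steps=3)


result = run_sequential([encode, denoise_loop], {"prompt": "a pear"})
print(result["x"], result["trace"])
```

The loop is itself used as one block inside the sequence, which is the same nesting the real denoise step uses.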
Differential Diffusion is an image-to-image workflow. Start with the IMAGE2IMAGE_BLOCKS preset, a collection of ModularPipelineBlocks for image-to-image generation.
from mindone.diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS

# For reference, the preset is an InsertableDict that maps block names to block classes:
# IMAGE2IMAGE_BLOCKS = InsertableDict([
#     ("text_encoder", StableDiffusionXLTextEncoderStep),
#     ("image_encoder", StableDiffusionXLVaeEncoderStep),
#     ("input", StableDiffusionXLInputStep),
#     ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
#     ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
#     ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
#     ("denoise", StableDiffusionXLDenoiseStep),
#     ("decode", StableDiffusionXLDecodeStep)
# ])
Pipeline and block states
Modular Diffusers uses state to communicate data between blocks. There are two types of states.
- PipelineState is a global state that can be used to track all inputs and outputs across all blocks.
- BlockState is a local view of relevant variables from PipelineState for an individual block.
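A toy illustration of the difference, in plain Python rather than the library's classes: the pipeline state holds every value, while a block only sees a view containing the names it declares as inputs, and then publishes its declared outputs back. The helpers make_block_view and write_back are hypothetical stand-ins written for this sketch.

```python
from types import SimpleNamespace
from typing import Dict, List


def make_block_view(pipeline_state: Dict, input_names: List[str]) -> SimpleNamespace:
    """Copy only the declared inputs out of the global state into a local view."""
    return SimpleNamespace(**{name: pipeline_state.get(name) for name in input_names})


def write_back(pipeline_state: Dict, block_view: SimpleNamespace, output_names: List[str]) -> None:
    """Publish the block's declared outputs to the global state."""
    for name in output_names:
        pipeline_state[name] = getattr(block_view, name)


pipeline_state = {"prompt": "a green pear", "num_inference_steps": 25, "latents": None}

# A block that declares two inputs and one output; "prompt" is not part of its view.
view = make_block_view(pipeline_state, ["num_inference_steps", "latents"])
view.thresholds = [k / view.num_inference_steps for k in range(view.num_inference_steps)]
write_back(pipeline_state, view, ["thresholds"])

print(hasattr(view, "prompt"), len(pipeline_state["thresholds"]))
```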
Customizing blocks
Differential Diffusion differs from standard image-to-image in its prepare_latents and denoise blocks. All the other blocks can be reused, but you'll need to modify these two.
Create placeholder ModularPipelineBlocks for prepare_latents and denoise by copying and modifying the existing ones.
Print the denoise block to see that it is composed of LoopSequentialPipelineBlocks with three sub-blocks: before_denoiser, denoiser, and after_denoiser. Only the before_denoiser sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.
denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
print(denoise_blocks)
Replace the StableDiffusionXLLoopBeforeDenoiser sub-block with the new SDXLDiffDiffLoopBeforeDenoiser block.
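from mindone.diffusers.modular_pipelines import ModularPipelineBlocks
from mindone.diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseLoopWrapper, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser

# Copy existing blocks as placeholders
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
    """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
    # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep

class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    # SDXLDiffDiffLoopBeforeDenoiser is defined in the denoise section below
    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]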
prepare_latents
The prepare_latents block requires the following changes.
- a processor to process the change map
- new inputs to accept the user-provided change map, timesteps for precomputing all the latents, and num_inference_steps to create the masks for updating the image regions
- updated computation in the __call__ method to process the change map, create the masks, and store them in the BlockState
class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
+           ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
+           InputParam("diffdiff_map", required=True),
-           InputParam("latent_timestep", required=True, type_hint=ms.Tensor),
+           InputParam("timesteps", type_hint=ms.Tensor),
+           InputParam("num_inference_steps", type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
+           OutputParam("original_latents", type_hint=ms.Tensor),
+           OutputParam("diffdiff_masks", type_hint=ms.Tensor),
        ]

    def __call__(self, components, state: PipelineState):
        # ... existing logic ...
+       # Process the change map and create one mask per inference step
+       diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
+       thresholds = mint.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
+       thresholds = thresholds.unsqueeze(1).unsqueeze(1)  # (num_inference_steps, 1, 1) so it broadcasts against the (H, W) map
+       block_state.diffdiff_masks = diffdiff_map.squeeze() > (thresholds + (block_state.denoising_start or 0))
+       block_state.original_latents = block_state.latents
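The mask construction is easier to follow on a tiny example. This sketch uses plain Python lists instead of MindSpore tensors and assumes a change map already scaled to [0, 1]: step k of N uses the threshold k/N (plus denoising_start, if set), and a pixel's mask is set at step k when its map value exceeds that threshold. The helper build_masks is hypothetical.

```python
from typing import List

Map = List[List[float]]


def build_masks(change_map: Map, num_inference_steps: int, denoising_start: float = 0.0) -> List[List[List[bool]]]:
    """One boolean mask per step: True where the map value exceeds that step's threshold."""
    masks = []
    for k in range(num_inference_steps):
        threshold = k / num_inference_steps + denoising_start
        masks.append([[value > threshold for value in row] for row in change_map])
    return masks


# A 1x4 "change map" with values already scaled to [0, 1].
change_map = [[0.0, 0.3, 0.6, 1.0]]
masks = build_masks(change_map, num_inference_steps=5)

# Count, per pixel, how many of the 5 steps have the mask set (thresholds 0.0, 0.2, 0.4, 0.6, 0.8).
steps_masked = [sum(mask[0][j] for mask in masks) for j in range(4)]
print(steps_masked)
```

Because the thresholds only grow, each pixel's mask switches from True to False at most once, so the map value decides at which step that pixel stops being masked.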
denoise
The before_denoiser sub-block requires the following changes.
- new inputs to accept a denoising_start parameter, and the original_latents and diffdiff_masks from the prepare_latents block
- updated computation in the __call__ method to apply Differential Diffusion
class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop for differential diffusion that prepares the latent input for the denoiser"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=ms.Tensor),
+           InputParam("denoising_start"),
+           InputParam("original_latents", type_hint=ms.Tensor),
+           InputParam("diffdiff_masks", type_hint=ms.Tensor),
        ]

    def __call__(self, components, block_state, i, t):
+       # Apply differential diffusion logic
+       if i == 0 and block_state.denoising_start is None:
+           block_state.latents = block_state.original_latents[:1]
+       else:
+           # Cast the boolean mask to the latent dtype before using it in arithmetic
+           block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1).to(block_state.latents.dtype)
+           block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
        # ... rest of existing logic ...
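Per pixel, the before_denoiser update above is a mask-weighted blend. The scalar sketch below uses plain Python and made-up values to show that a mask of 1 takes the precomputed original latent for step i, while a mask of 0 keeps the current latent. The helper blend and the sample values are hypothetical.

```python
from typing import List


def blend(original_latent: List[float], current_latent: List[float], mask: List[bool]) -> List[float]:
    """original * mask + current * (1 - mask), with the boolean mask cast to 0.0 / 1.0."""
    out = []
    for orig, cur, m in zip(original_latent, current_latent, mask):
        weight = float(m)  # cast to a number before computing 1 - mask
        out.append(orig * weight + cur * (1.0 - weight))
    return out


original_step_i = [0.5, -1.0, 2.0]  # hypothetical noised original latents for step i
current = [0.1, 0.2, 0.3]           # hypothetical latents carried over from the previous step
mask_step_i = [True, False, True]

print(blend(original_step_i, current, mask_step_i))
```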
Assembling the blocks
You should have all the blocks you need at this point to create a ModularPipeline.
Copy the existing IMAGE2IMAGE_BLOCKS preset and, for the set_timesteps block, use the set_timesteps block from TEXT2IMAGE_BLOCKS because Differential Diffusion doesn't require a strength parameter.
Set the prepare_latents and denoise blocks to the SDXLDiffDiffPrepareLatentsStep and SDXLDiffDiffDenoiseStep blocks you just modified.
Call SequentialPipelineBlocks.from_blocks_dict on the blocks to create a SequentialPipelineBlocks.
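from mindone.diffusers.modular_pipelines import SequentialPipelineBlocks
from mindone.diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS, TEXT2IMAGE_BLOCKS

DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
print(dd_blocks)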
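Replacing an entry in the copied preset keeps its position, so the replacement blocks run at the same point in the sequence, and the original preset is left untouched. The sketch below checks this with a plain collections.OrderedDict standing in for the preset (InsertableDict is assumed here to follow the same ordering rules as a regular ordered dict) and with strings in place of block classes.

```python
from collections import OrderedDict

# Strings stand in for the block classes of the image-to-image preset.
image2image = OrderedDict([
    ("text_encoder", "TextEncoderStep"),
    ("image_encoder", "VaeEncoderStep"),
    ("input", "InputStep"),
    ("set_timesteps", "Img2ImgSetTimestepsStep"),
    ("prepare_latents", "Img2ImgPrepareLatentsStep"),
    ("prepare_add_cond", "Img2ImgPrepareAdditionalConditioningStep"),
    ("denoise", "DenoiseStep"),
    ("decode", "DecodeStep"),
])

# Copy, then overwrite existing keys: values change, positions do not.
diffdiff = image2image.copy()
diffdiff["set_timesteps"] = "Text2ImgSetTimestepsStep"
diffdiff["prepare_latents"] = "SDXLDiffDiffPrepareLatentsStep"
diffdiff["denoise"] = "SDXLDiffDiffDenoiseStep"

print(list(diffdiff) == list(image2image))  # same block order
print(image2image["denoise"])               # the original preset keeps its own value
```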
ModularPipeline
Convert the SequentialPipelineBlocks into a ModularPipeline with the init_pipeline method. This sets up the expected components from the specs in a modular_model_index.json file, but does not load them yet. Explicitly load the components by calling ModularPipeline.load_default_components.
It is a good idea to initialize a ComponentsManager with the pipeline to help manage the different components. Once you call load_default_components, the components are registered to the ComponentsManager and can be shared between workflows. The example below uses the collection argument to assign the components a "diffdiff" label for better organization.
import mindspore as ms
from mindone.diffusers.modular_pipelines import ComponentsManager

components = ComponentsManager()
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
dd_pipeline.load_default_components(mindspore_dtype=ms.float16)
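One way to picture why registered components can be shared: a manager keeps a single object per component, tagged with its collection label, and pipelines look components up instead of loading a second copy. The class below is a hypothetical toy registry written only for illustration; it is not the ComponentsManager API and omits everything else the real manager does.

```python
from typing import Any, Dict, List, Tuple


class ToyComponentRegistry:
    """Hypothetical stand-in: stores each loaded component once, tagged with a collection label."""

    def __init__(self) -> None:
        self._components: Dict[str, Tuple[Any, str]] = {}

    def add(self, name: str, component: Any, collection: str) -> None:
        self._components[name] = (component, collection)

    def get(self, name: str) -> Any:
        return self._components[name][0]

    def names_in(self, collection: str) -> List[str]:
        return [name for name, (_, label) in self._components.items() if label == collection]


registry = ToyComponentRegistry()
registry.add("vae", object(), collection="diffdiff")
registry.add("unet", object(), collection="diffdiff")

# Two "pipelines" asking for the vae receive the very same object rather than a second copy.
vae_for_pipeline_a = registry.get("vae")
vae_for_pipeline_b = registry.get("vae")
print(vae_for_pipeline_a is vae_for_pipeline_b, registry.names_in("diffdiff"))
```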
Adding workflows
Other workflows can be added to the ModularPipeline to support additional features without rewriting the entire pipeline from scratch.
This section demonstrates how to add an IP-Adapter or ControlNet.
IP-Adapter
Stable Diffusion XL already has a preset IP-Adapter block that you can use without any changes to the existing Differential Diffusion pipeline.
from mindone.diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
Use the sub_blocks.insert method to insert it into the SequentialPipelineBlocks. The example below inserts the ip_adapter_block at position 0. Print the blocks to see that the ip_adapter_block is added and that it requires an ip_adapter_image. Inserting it also adds two components to the pipeline, the image_encoder and the feature_extractor.
dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
Call init_pipeline to initialize a ModularPipeline and use load_default_components to load the model components. Then load the IP-Adapter weights and set its scale before running the pipeline.
import mindspore as ms
import numpy as np
from mindone.diffusers.utils import load_image

dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_default_components(mindspore_dtype=ms.float16)
dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
dd_pipeline.loader.set_ip_adapter_scale(0.6)
ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
prompt = "a green pear"
negative_prompt = "blurry"
generator = np.random.default_rng(42)
image = dd_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=25,
generator=generator,
ip_adapter_image=ip_adapter_image,
diffdiff_map=mask,
image=image,
output="images"
)[0]
ControlNet
Stable Diffusion XL already has a preset ControlNet block that can readily be used.
from mindone.diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
control_input_block = StableDiffusionXLAutoControlNetInputStep()
However, it requires modifying the denoise block because that's where the ControlNet injects the control information into the UNet.
Modify the denoise block by replacing the StableDiffusionXLLoopDenoiser sub-block with the StableDiffusionXLControlNetLoopDenoiser.
from mindone.diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLControlNetLoopDenoiser

class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
Insert the controlnet_input block and replace the denoise block with the new controlnet_denoise_block. Then initialize a ModularPipeline and load the default components into it.
dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_default_components(mindspore_dtype=ms.float16)
control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
prompt = "a green pear"
negative_prompt = "blurry"
generator = np.random.default_rng(42)
image = dd_pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=25,
generator=generator,
control_image=control_image,
controlnet_conditioning_scale=0.5,
diffdiff_map=mask,
image=image,
output="images"
)[0]
AutoPipelineBlocks
The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single ModularPipeline by using AutoPipelineBlocks. This allows automatically selecting which sub-blocks to run based on inputs like control_image or ip_adapter_image. If none of these inputs are passed, the pipeline defaults to Differential Diffusion.
Use block_trigger_inputs to only run the SDXLDiffDiffControlNetDenoiseStep block when ControlNet conditioning is available. The trigger input is controlnet_cond, the processed control image that the controlnet_input block produces when a control_image is provided. Otherwise, the SDXLDiffDiffDenoiseStep is used.
from mindone.diffusers.modular_pipelines import AutoPipelineBlocks

class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
    block_names = ["controlnet_denoise", "denoise"]
    block_trigger_inputs = ["controlnet_cond", None]
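Conceptually, the selection behaves like the toy function below: walk the trigger list in order, pick the first block whose trigger input has a value, and fall back to the block whose trigger is None. This is a plain-Python illustration of the rule described above, not the AutoPipelineBlocks source; select_block is a hypothetical name.

```python
from typing import Dict, List, Optional


def select_block(block_names: List[str], trigger_inputs: List[Optional[str]], state: Dict) -> Optional[str]:
    """Return the first block whose trigger input has a value, else the default (trigger None)."""
    default = None
    for name, trigger in zip(block_names, trigger_inputs):
        if trigger is None:
            default = name  # remembered, but only used if no trigger matches
        elif state.get(trigger) is not None:
            return name
    return default


names = ["controlnet_denoise", "denoise"]
triggers = ["controlnet_cond", None]

print(select_block(names, triggers, {"controlnet_cond": "processed control image"}))
print(select_block(names, triggers, {}))
```

In this toy, a trigger list without a None entry selects nothing when its input is missing, which is how an optional step such as an IP-Adapter block can be skipped entirely.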
Copy the preset as before, set the denoise block to SDXLDiffDiffAutoDenoiseStep, and add the ip_adapter and controlnet_input blocks.
DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input", StableDiffusionXLAutoControlNetInputStep, 7)
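Insertion indices count positions in the current order, so index 7 lands just before denoise only because ip_adapter already occupies position 0. A quick check with a plain OrderedDict and a hypothetical insert_at helper (a stand-in for illustration, not InsertableDict.insert itself), using strings in place of block classes:

```python
from collections import OrderedDict


def insert_at(blocks: OrderedDict, key: str, value: str, index: int) -> OrderedDict:
    """Return a new OrderedDict with key/value inserted at the given position."""
    items = list(blocks.items())
    items.insert(index, (key, value))
    return OrderedDict(items)


names = ["text_encoder", "image_encoder", "input", "set_timesteps",
         "prepare_latents", "prepare_add_cond", "denoise", "decode"]
blocks = OrderedDict((name, name + "_block") for name in names)

blocks = insert_at(blocks, "ip_adapter", "ip_adapter_block", 0)
blocks = insert_at(blocks, "controlnet_input", "controlnet_input_block", 7)

order = list(blocks)
print(order.index("controlnet_input") + 1 == order.index("denoise"))
```

The same reasoning applies to the earlier sub_blocks.insert calls: inserting at 7 without first adding ip_adapter would place controlnet_input after denoise instead.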
Call SequentialPipelineBlocks.from_blocks_dict to create a SequentialPipelineBlocks, then create a ModularPipeline and load the model components to run it.
dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
dd_pipeline.load_default_components(mindspore_dtype=ms.float16)
Share
Add your ModularPipeline to the Hub with save_pretrained and set the push_to_hub argument to True.
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
Other users can load the ModularPipeline with from_pretrained.
import mindspore as ms
from mindone.diffusers.modular_pipelines import ModularPipeline, ComponentsManager
components = ComponentsManager()
diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
diffdiff_pipeline.load_default_components(mindspore_dtype=ms.float16)