Reduce memory usage
A barrier to using diffusion models is the large amount of memory required. To overcome this challenge, there are several memory-reducing techniques you can use to run even some of the largest models on Ascend. Some of these techniques can even be combined to further reduce memory usage.
Tip
In many cases, optimizing for memory or speed leads to improved performance in the other, so you should try to optimize for both whenever you can. This guide focuses on minimizing memory usage, but you can also learn more about how to Speed up inference.
Memory-efficient attention
Recent work on optimizing bandwidth in the attention block has generated huge speed-ups and reductions in memory usage. The most recent type of memory-efficient attention is Flash Attention (you can check out the original code at HazyResearch/flash-attention).
AttnProcessors will automatically invoke flash-attention for scaled dot-product attention calculations when the MindSpore version and hardware support it; otherwise, they fall back to the original formula-based attention calculation.
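When flash-attention is unavailable, the fallback is standard scaled dot-product attention computed from the formula. The following is a minimal, illustrative sketch of that calculation (not the library's internal code), assuming query/key/value tensors of shape (batch, heads, seq_len, head_dim):

```python
import mindspore as ms
from mindspore import ops

def vanilla_attention(query: ms.Tensor, key: ms.Tensor, value: ms.Tensor) -> ms.Tensor:
    """Formula-based attention: softmax(Q @ K^T / sqrt(head_dim)) @ V."""
    scale = query.shape[-1] ** -0.5
    scores = ops.matmul(query, key.swapaxes(-2, -1)) * scale  # (batch, heads, q_len, k_len)
    weights = ops.softmax(scores, axis=-1)
    return ops.matmul(weights, value)
```

A flash-attention kernel computes the same result but never materializes the full attention-scores matrix, which is where the memory savings come from.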
Tip
Note that you need to control data type conversion manually because the flash-attention operator in MindSpore only supports the `float16` and `bfloat16` data types. When the attention interface receives data of an unsupported data type, it behaves as follows: if `force_cast_dtype` is not None, the inputs are cast to `force_cast_dtype` for the computation and the result is cast back to the original data type afterward; if `force_cast_dtype` is None, the interface falls back to the original formula-based attention calculation (see the sketch at the end of this section).
By default, `force_cast_dtype` is set to `mindspore.float16`. Call `set_flash_attention_force_cast_dtype` on the pipeline to change it, and you can also call `enable_flash_sdp(False)` to disable flash-attention:
```python
from mindone.diffusers import DiffusionPipeline
import mindspore as ms

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    mindspore_dtype=ms.float16,
    use_safetensors=True,
)

# Optional: You can set `force_cast_dtype` for flash-attention on model-level or pipeline-level.
# Default: mindspore.float16
pipe.set_flash_attention_force_cast_dtype(force_cast_dtype=ms.bfloat16)
pipe.unet.set_flash_attention_force_cast_dtype(force_cast_dtype=None)

# Optional: You can disable flash-attention on model-level or pipeline-level:
# pipe.enable_flash_sdp(False)
# pipe.vae.enable_flash_sdp(True)

sample = pipe("a small cat")
```
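Conceptually, the cast-and-restore behavior described in the tip above works like the sketch below. This is illustrative only: `run_flash_attention` is a hypothetical placeholder (not a mindone.diffusers or MindSpore API), and `vanilla_attention` refers to the formula-based sketch earlier in this section.

```python
import mindspore as ms

# Dtypes accepted by the flash-attention operator in MindSpore.
FLASH_SUPPORTED_DTYPES = (ms.float16, ms.bfloat16)

def attention_with_optional_cast(query, key, value, force_cast_dtype=ms.float16):
    """Illustrative dispatch only: cast to a supported dtype for flash-attention, then restore."""
    input_dtype = query.dtype
    if input_dtype in FLASH_SUPPORTED_DTYPES:
        # Supported dtype: use the flash-attention path directly.
        return run_flash_attention(query, key, value)  # hypothetical placeholder
    if force_cast_dtype is not None:
        # Unsupported dtype: cast, compute, then restore the original dtype.
        out = run_flash_attention(
            query.astype(force_cast_dtype),
            key.astype(force_cast_dtype),
            value.astype(force_cast_dtype),
        )
        return out.astype(input_dtype)
    # force_cast_dtype is None: fall back to the formula-based calculation.
    return vanilla_attention(query, key, value)
```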